diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 18649c2c05..c86f783c5b 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -314,3 +314,77 @@ class TestRatelimiter(unittest.HomeserverTestCase):
# Check that we get rate limited after using that token.
self.assertFalse(consume_at(11.1))
+
+ def test_record_action_which_doesnt_fill_bucket(self) -> None:
+ limiter = Ratelimiter(
+ store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+ )
+
+ # Observe two actions, leaving room in the bucket for one more.
+ limiter.record_action(requester=None, key="a", n_actions=2, _time_now_s=0.0)
+
+ # We should be able to take a new action now.
+ success, _ = self.get_success_or_raise(
+ limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
+ )
+ self.assertTrue(success)
+
+ # ... but not two.
+ success, _ = self.get_success_or_raise(
+ limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
+ )
+ self.assertFalse(success)
+
+ def test_record_action_which_fills_bucket(self) -> None:
+ limiter = Ratelimiter(
+ store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+ )
+
+ # Observe three actions, filling up the bucket.
+ limiter.record_action(requester=None, key="a", n_actions=3, _time_now_s=0.0)
+
+ # We should be unable to take a new action now.
+ success, _ = self.get_success_or_raise(
+ limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
+ )
+ self.assertFalse(success)
+
+ # If we wait 10 seconds to leak a token, we should be able to take one action...
+ success, _ = self.get_success_or_raise(
+ limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
+ )
+ self.assertTrue(success)
+
+ # ... but not two.
+ success, _ = self.get_success_or_raise(
+ limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
+ )
+ self.assertFalse(success)
+
+ def test_record_action_which_overfills_bucket(self) -> None:
+ limiter = Ratelimiter(
+ store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+ )
+
+ # Observe four actions, exceeding the bucket.
+ limiter.record_action(requester=None, key="a", n_actions=4, _time_now_s=0.0)
+
+ # We should be prevented from taking a new action now.
+ success, _ = self.get_success_or_raise(
+ limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
+ )
+ self.assertFalse(success)
+
+ # If we wait 10 seconds to leak a token, we should be unable to take an action
+ # because the bucket is still full.
+ success, _ = self.get_success_or_raise(
+ limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
+ )
+ self.assertFalse(success)
+
+ # But after another 10 seconds we leak a second token, giving us room for
+ # action.
+ success, _ = self.get_success_or_raise(
+ limiter.can_do_action(requester=None, key="a", _time_now_s=20.0)
+ )
+ self.assertTrue(success)
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 9f1115dd23..c6dd99316a 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from http import HTTPStatus
from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError
@@ -50,7 +51,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
complexity = channel.json_body["v1"]
self.assertTrue(complexity > 0, complexity)
@@ -62,7 +63,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
complexity = channel.json_body["v1"]
self.assertEqual(complexity, 1.23)
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
index 268a48d7ba..d2bda07198 100644
--- a/tests/federation/test_federation_client.py
+++ b/tests/federation/test_federation_client.py
@@ -45,7 +45,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
# mock up some events to use in the response.
# In real life, these would have things in `prev_events` and `auth_events`, but that's
# a bit annoying to mock up, and the code under test doesn't care, so we don't bother.
- create_event_dict = self.add_hashes_and_signatures(
+ create_event_dict = self.add_hashes_and_signatures_from_other_server(
{
"room_id": test_room_id,
"type": "m.room.create",
@@ -57,7 +57,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
"origin_server_ts": 500,
}
)
- member_event_dict = self.add_hashes_and_signatures(
+ member_event_dict = self.add_hashes_and_signatures_from_other_server(
{
"room_id": test_room_id,
"type": "m.room.member",
@@ -69,7 +69,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
"origin_server_ts": 600,
}
)
- pl_event_dict = self.add_hashes_and_signatures(
+ pl_event_dict = self.add_hashes_and_signatures_from_other_server(
{
"room_id": test_room_id,
"type": "m.room.power_levels",
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 413b3c9426..3a6ef221ae 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from http import HTTPStatus
from parameterized import parameterized
@@ -20,7 +21,6 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.server import DEFAULT_ROOM_VERSION
-from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events import make_event_from_dict
from synapse.federation.federation_server import server_matches_acl_event
from synapse.rest import admin
@@ -59,7 +59,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
"/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
query_content,
)
- self.assertEqual(400, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON")
@@ -120,7 +120,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/v1/state/%s?event_id=xyz" % (room_1,)
)
- self.assertEqual(403, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -148,13 +148,13 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
tok2 = self.login("fozzie", "bear")
self.helper.join(self._room_id, second_member_user_id, tok=tok2)
- def _make_join(self, user_id) -> JsonDict:
+ def _make_join(self, user_id: str) -> JsonDict:
channel = self.make_signed_federation_request(
"GET",
f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}"
f"?ver={DEFAULT_ROOM_VERSION}",
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
return channel.json_body
def test_send_join(self):
@@ -163,18 +163,16 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
join_result = self._make_join(joining_user)
join_event_dict = join_result["event"]
- add_hashes_and_signatures(
- KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
+ self.add_hashes_and_signatures_from_other_server(
join_event_dict,
- signature_name=self.OTHER_SERVER_NAME,
- signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
+ KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
)
channel = self.make_signed_federation_request(
"PUT",
f"/_matrix/federation/v2/send_join/{self._room_id}/x",
content=join_event_dict,
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# we should get complete room state back
returned_state = [
@@ -220,18 +218,16 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
join_result = self._make_join(joining_user)
join_event_dict = join_result["event"]
- add_hashes_and_signatures(
- KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
+ self.add_hashes_and_signatures_from_other_server(
join_event_dict,
- signature_name=self.OTHER_SERVER_NAME,
- signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
+ KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
)
channel = self.make_signed_federation_request(
"PUT",
f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true",
content=join_event_dict,
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# expect a reduced room state
returned_state = [
@@ -264,6 +260,67 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
)
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
+ @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 3}})
+ def test_make_join_respects_room_join_rate_limit(self) -> None:
+ # In the test setup, two users join the room. Since the rate limiter burst
+ # count is 3, a new make_join request to the room should be accepted.
+
+ joining_user = "@ronniecorbett:" + self.OTHER_SERVER_NAME
+ self._make_join(joining_user)
+
+ # Now have a new local user join the room. This saturates the rate limiter
+ # bucket, so the next make_join should be denied.
+ new_local_user = self.register_user("animal", "animal")
+ token = self.login("animal", "animal")
+ self.helper.join(self._room_id, new_local_user, tok=token)
+
+ joining_user = "@ronniebarker:" + self.OTHER_SERVER_NAME
+ channel = self.make_signed_federation_request(
+ "GET",
+ f"/_matrix/federation/v1/make_join/{self._room_id}/{joining_user}"
+ f"?ver={DEFAULT_ROOM_VERSION}",
+ )
+ self.assertEqual(channel.code, HTTPStatus.TOO_MANY_REQUESTS, channel.json_body)
+
+ @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 3}})
+ def test_send_join_contributes_to_room_join_rate_limit_and_is_limited(self) -> None:
+ # Make two make_join requests up front. (These are rate limited, but do not
+ # contribute to the rate limit.)
+ join_event_dicts = []
+ for i in range(2):
+ joining_user = f"@misspiggy{i}:{self.OTHER_SERVER_NAME}"
+ join_result = self._make_join(joining_user)
+ join_event_dict = join_result["event"]
+ self.add_hashes_and_signatures_from_other_server(
+ join_event_dict,
+ KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
+ )
+ join_event_dicts.append(join_event_dict)
+
+ # In the test setup, two users join the room. Since the rate limiter burst
+ # count is 3, the first send_join should be accepted...
+ channel = self.make_signed_federation_request(
+ "PUT",
+ f"/_matrix/federation/v2/send_join/{self._room_id}/join0",
+ content=join_event_dicts[0],
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # ... but the second should be denied.
+ channel = self.make_signed_federation_request(
+ "PUT",
+ f"/_matrix/federation/v2/send_join/{self._room_id}/join1",
+ content=join_event_dicts[1],
+ )
+ self.assertEqual(channel.code, HTTPStatus.TOO_MANY_REQUESTS, channel.json_body)
+
+ # NB: we could write a test which checks that the send_join event is seen
+ # by other workers over replication, and that they update their rate limit
+ # buckets accordingly. I'm going to assume that the join event gets sent over
+ # replication, at which point the tests.handlers.room_member test
+ # test_local_users_joining_on_another_worker_contribute_to_rate_limit
+ # is probably sufficient to reassure that the bucket is updated.
+
def _create_acl_event(content):
return make_event_from_dict(
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index d21c11b716..0d048207b7 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
+from http import HTTPStatus
from typing import Dict, List
from synapse.api.constants import EventTypes, JoinRules, Membership
@@ -255,7 +256,7 @@ class FederationKnockingTestCase(
RoomVersions.V7.identifier,
),
)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
# Note: We don't expect the knock membership event to be sent over federation as
# part of the stripped room state, as the knocking homeserver already has that
@@ -293,7 +294,7 @@ class FederationKnockingTestCase(
% (room_id, signed_knock_event.event_id),
signed_knock_event_json,
)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
# Check that we got the stripped room state in return
room_state_events = channel.json_body["knock_state_events"]
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index d96d5aa138..b17af2725b 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -50,7 +50,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_scheduler = Mock()
hs = Mock()
hs.get_datastores.return_value = Mock(main=self.mock_store)
- self.mock_store.get_received_ts.return_value = make_awaitable(0)
+ self.mock_store.get_appservice_last_pos.return_value = make_awaitable(None)
self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable(
None
@@ -76,9 +76,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
)
- self.mock_store.get_new_events_for_appservice.side_effect = [
- make_awaitable((0, [])),
- make_awaitable((1, [event])),
+ self.mock_store.get_all_new_events_stream.side_effect = [
+ make_awaitable((0, [], {})),
+ make_awaitable((1, [event], {event.event_id: 0})),
]
self.handler.notify_interested_services(RoomStreamToken(None, 1))
@@ -95,8 +95,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = make_awaitable(True)
- self.mock_store.get_new_events_for_appservice.side_effect = [
- make_awaitable((0, [event])),
+ self.mock_store.get_all_new_events_stream.side_effect = [
+ make_awaitable((0, [event], {event.event_id: 0})),
]
self.handler.notify_interested_services(RoomStreamToken(None, 0))
@@ -112,8 +112,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = make_awaitable(True)
- self.mock_store.get_new_events_for_appservice.side_effect = [
- make_awaitable((0, [event])),
+ self.mock_store.get_all_new_events_stream.side_effect = [
+ make_awaitable((0, [event], {event.event_id: 0})),
]
self.handler.notify_interested_services(RoomStreamToken(None, 0))
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 9b9c11fab7..8a0bb91f40 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List, cast
+from typing import cast
from unittest import TestCase
from twisted.test.proto_helpers import MemoryReactor
@@ -50,8 +50,6 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
hs = self.setup_test_homeserver(federation_http_client=None)
self.handler = hs.get_federation_handler()
self.store = hs.get_datastores().main
- self.state_storage_controller = hs.get_storage_controllers().state
- self._event_auth_handler = hs.get_event_auth_handler()
return hs
def test_exchange_revoked_invite(self) -> None:
@@ -256,7 +254,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
]
for _ in range(0, 8):
event = make_event_from_dict(
- self.add_hashes_and_signatures(
+ self.add_hashes_and_signatures_from_other_server(
{
"origin_server_ts": 1,
"type": "m.room.message",
@@ -314,142 +312,6 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
)
self.get_success(d)
- def test_backfill_floating_outlier_membership_auth(self) -> None:
- """
- As the local homeserver, check that we can properly process a federated
- event from the OTHER_SERVER with auth_events that include a floating
- membership event from the OTHER_SERVER.
-
- Regression test, see #10439.
- """
- OTHER_SERVER = "otherserver"
- OTHER_USER = "@otheruser:" + OTHER_SERVER
-
- # create the room
- user_id = self.register_user("kermit", "test")
- tok = self.login("kermit", "test")
- room_id = self.helper.create_room_as(
- room_creator=user_id,
- is_public=True,
- tok=tok,
- extra_content={
- "preset": "public_chat",
- },
- )
- room_version = self.get_success(self.store.get_room_version(room_id))
-
- prev_event_ids = self.get_success(self.store.get_prev_events_for_room(room_id))
- (
- most_recent_prev_event_id,
- most_recent_prev_event_depth,
- ) = self.get_success(self.store.get_max_depth_of(prev_event_ids))
- # mapping from (type, state_key) -> state_event_id
- assert most_recent_prev_event_id is not None
- prev_state_map = self.get_success(
- self.state_storage_controller.get_state_ids_for_event(
- most_recent_prev_event_id
- )
- )
- # List of state event ID's
- prev_state_ids = list(prev_state_map.values())
- auth_event_ids = prev_state_ids
- auth_events = list(
- self.get_success(self.store.get_events(auth_event_ids)).values()
- )
-
- # build a floating outlier member state event
- fake_prev_event_id = "$" + random_string(43)
- member_event_dict = {
- "type": EventTypes.Member,
- "content": {
- "membership": "join",
- },
- "state_key": OTHER_USER,
- "room_id": room_id,
- "sender": OTHER_USER,
- "depth": most_recent_prev_event_depth,
- "prev_events": [fake_prev_event_id],
- "origin_server_ts": self.clock.time_msec(),
- "signatures": {OTHER_SERVER: {"ed25519:key_version": "SomeSignatureHere"}},
- }
- builder = self.hs.get_event_builder_factory().for_room_version(
- room_version, member_event_dict
- )
- member_event = self.get_success(
- builder.build(
- prev_event_ids=member_event_dict["prev_events"],
- auth_event_ids=self._event_auth_handler.compute_auth_events(
- builder,
- prev_state_map,
- for_verification=False,
- ),
- depth=member_event_dict["depth"],
- )
- )
- # Override the signature added from "test" homeserver that we created the event with
- member_event.signatures = member_event_dict["signatures"]
-
- # Add the new member_event to the StateMap
- updated_state_map = dict(prev_state_map)
- updated_state_map[
- (member_event.type, member_event.state_key)
- ] = member_event.event_id
- auth_events.append(member_event)
-
- # build and send an event authed based on the member event
- message_event_dict = {
- "type": EventTypes.Message,
- "content": {},
- "room_id": room_id,
- "sender": OTHER_USER,
- "depth": most_recent_prev_event_depth,
- "prev_events": prev_event_ids.copy(),
- "origin_server_ts": self.clock.time_msec(),
- "signatures": {OTHER_SERVER: {"ed25519:key_version": "SomeSignatureHere"}},
- }
- builder = self.hs.get_event_builder_factory().for_room_version(
- room_version, message_event_dict
- )
- message_event = self.get_success(
- builder.build(
- prev_event_ids=message_event_dict["prev_events"],
- auth_event_ids=self._event_auth_handler.compute_auth_events(
- builder,
- updated_state_map,
- for_verification=False,
- ),
- depth=message_event_dict["depth"],
- )
- )
- # Override the signature added from "test" homeserver that we created the event with
- message_event.signatures = message_event_dict["signatures"]
-
- # Stub the /event_auth response from the OTHER_SERVER
- async def get_event_auth(
- destination: str, room_id: str, event_id: str
- ) -> List[EventBase]:
- return [
- event_from_pdu_json(ae.get_pdu_json(), room_version=room_version)
- for ae in auth_events
- ]
-
- self.handler.federation_client.get_event_auth = get_event_auth # type: ignore[assignment]
-
- with LoggingContext("receive_pdu"):
- # Fake the OTHER_SERVER federating the message event over to our local homeserver
- d = run_in_background(
- self.hs.get_federation_event_handler().on_receive_pdu,
- OTHER_SERVER,
- message_event,
- )
- self.get_success(d)
-
- # Now try and get the events on our local homeserver
- stored_event = self.get_success(
- self.store.get_event(message_event.event_id, allow_none=True)
- )
- self.assertTrue(stored_event is not None)
-
@unittest.override_config(
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
)
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 4b1a8f04db..51c8dd6498 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -104,7 +104,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
# mock up a load of state events which we are missing
state_events = [
make_event_from_dict(
- self.add_hashes_and_signatures(
+ self.add_hashes_and_signatures_from_other_server(
{
"type": "test_state_type",
"state_key": f"state_{i}",
@@ -131,7 +131,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
# Depending on the test, we either persist this upfront (as an outlier),
# or let the server request it.
prev_event = make_event_from_dict(
- self.add_hashes_and_signatures(
+ self.add_hashes_and_signatures_from_other_server(
{
"type": "test_regular_type",
"room_id": room_id,
@@ -165,7 +165,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
# mock up a regular event to pass into _process_pulled_event
pulled_event = make_event_from_dict(
- self.add_hashes_and_signatures(
+ self.add_hashes_and_signatures_from_other_server(
{
"type": "test_regular_type",
"room_id": room_id,
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 82b3bb3b73..4c62449c89 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -14,6 +14,7 @@
"""Tests for the password_auth_provider interface"""
+from http import HTTPStatus
from typing import Any, Type, Union
from unittest.mock import Mock
@@ -188,14 +189,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(True)
channel = self._send_password_login("u", "p")
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@u:test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
mock_password_provider.reset_mock()
# login with mxid should work too
channel = self._send_password_login("@u:bz", "p")
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@u:bz", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
mock_password_provider.reset_mock()
@@ -204,7 +205,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# in these cases, but at least we can guard against the API changing
# unexpectedly
channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with(
"@ USER🙂NAME :test", " pASS😢word "
@@ -258,10 +259,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._send_password_login("u", "p")
- self.assertEqual(channel.code, 403, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
channel = self._send_password_login("localuser", "localpass")
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@localuser:test", channel.json_body["user_id"])
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
@@ -382,7 +383,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("u", "p")
- self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_password.assert_not_called()
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
@@ -406,14 +407,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login with missing param should be rejected
channel = self._send_login("test.login_type", "u")
- self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None)
)
channel = self._send_login("test.login_type", "u", test_field="y")
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"}
@@ -427,7 +428,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
("@ MALFORMED! :bz", None)
)
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
" USER🙂NAME ", "test.login_type", {"test_field": " abc "}
@@ -510,7 +511,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
("@user:bz", callback)
)
channel = self._send_login("test.login_type", "u", test_field="y")
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"}
@@ -549,7 +550,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
- self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()
@override_config(
@@ -584,7 +585,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
- self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()
@override_config(
@@ -615,7 +616,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
- self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_password.assert_not_called()
@@ -646,13 +647,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
("@localuser:test", None)
)
channel = self._send_login("test.login_type", "localuser", test_field="")
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
tok1 = channel.json_body["access_token"]
channel = self._send_login(
"test.login_type", "localuser", test_field="", device_id="dev2"
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
# make the initial request which returns a 401
channel = self._delete_device(tok1, "dev2")
@@ -721,7 +722,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# password login shouldn't work and should be rejected with a 400
# ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
- self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
def test_on_logged_out(self):
"""Tests that the on_logged_out callback is called when the user logs out."""
@@ -884,7 +885,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
},
access_token=tok,
)
- self.assertEqual(channel.code, 403, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
self.assertEqual(
channel.json_body["errcode"],
Codes.THREEPID_DENIED,
@@ -906,7 +907,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
},
access_token=tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertIn("sid", channel.json_body)
m.assert_called_once_with("email", "bar@test.com", registration)
@@ -949,12 +950,12 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"register",
{"auth": {"session": session, "type": LoginType.DUMMY}},
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
return channel.json_body
def _get_login_flows(self) -> JsonDict:
channel = self.make_request("GET", "/_matrix/client/r0/login")
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
return channel.json_body["flows"]
def _send_password_login(self, user: str, password: str) -> FakeChannel:
diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
new file mode 100644
index 0000000000..254e7e4b80
--- /dev/null
+++ b/tests/handlers/test_room_member.py
@@ -0,0 +1,290 @@
+from http import HTTPStatus
+from unittest.mock import Mock, patch
+
+from twisted.test.proto_helpers import MemoryReactor
+
+import synapse.rest.admin
+import synapse.rest.client.login
+import synapse.rest.client.room
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.errors import LimitExceededError
+from synapse.crypto.event_signing import add_hashes_and_signatures
+from synapse.events import FrozenEventV3
+from synapse.federation.federation_client import SendJoinResult
+from synapse.server import HomeServer
+from synapse.types import UserID, create_requester
+from synapse.util import Clock
+
+from tests.replication._base import RedisMultiWorkerStreamTestCase
+from tests.server import make_request
+from tests.test_utils import make_awaitable
+from tests.unittest import FederatingHomeserverTestCase, override_config
+
+
+class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.client.login.register_servlets,
+ synapse.rest.client.room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.handler = hs.get_room_member_handler()
+
+ # Create three users.
+ self.alice = self.register_user("alice", "pass")
+ self.alice_token = self.login("alice", "pass")
+ self.bob = self.register_user("bob", "pass")
+ self.bob_token = self.login("bob", "pass")
+ self.chris = self.register_user("chris", "pass")
+ self.chris_token = self.login("chris", "pass")
+
+ # Create a room on this homeserver. Note that this counts as a join: it
+ # contributes to the rate limter's count of actions
+ self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
+
+ self.intially_unjoined_room_id = f"!example:{self.OTHER_SERVER_NAME}"
+
+ @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
+ def test_local_user_local_joins_contribute_to_limit_and_are_limited(self) -> None:
+ # The rate limiter has accumulated one token from Alice's join after the create
+ # event.
+ # Try joining the room as Bob.
+ self.get_success(
+ self.handler.update_membership(
+ requester=create_requester(self.bob),
+ target=UserID.from_string(self.bob),
+ room_id=self.room_id,
+ action=Membership.JOIN,
+ )
+ )
+
+ # The rate limiter bucket is full. A second join should be denied.
+ self.get_failure(
+ self.handler.update_membership(
+ requester=create_requester(self.chris),
+ target=UserID.from_string(self.chris),
+ room_id=self.room_id,
+ action=Membership.JOIN,
+ ),
+ LimitExceededError,
+ )
+
+ @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
+ def test_local_user_profile_edits_dont_contribute_to_limit(self) -> None:
+ # The rate limiter has accumulated one token from Alice's join after the create
+ # event. Alice should still be able to change her displayname.
+ self.get_success(
+ self.handler.update_membership(
+ requester=create_requester(self.alice),
+ target=UserID.from_string(self.alice),
+ room_id=self.room_id,
+ action=Membership.JOIN,
+ content={"displayname": "Alice Cooper"},
+ )
+ )
+
+ # Still room in the limiter bucket. Chris's join should be accepted.
+ self.get_success(
+ self.handler.update_membership(
+ requester=create_requester(self.chris),
+ target=UserID.from_string(self.chris),
+ room_id=self.room_id,
+ action=Membership.JOIN,
+ )
+ )
+
+ @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 1}})
+ def test_remote_joins_contribute_to_rate_limit(self) -> None:
+ # Join once, to fill the rate limiter bucket.
+ #
+ # To do this we have to mock the responses from the remote homeserver.
+ # We also patch out a bunch of event checks on our end. All we're really
+ # trying to check here is that remote joins will bump the rate limter when
+ # they are persisted.
+ create_event_source = {
+ "auth_events": [],
+ "content": {
+ "creator": f"@creator:{self.OTHER_SERVER_NAME}",
+ "room_version": self.hs.config.server.default_room_version.identifier,
+ },
+ "depth": 0,
+ "origin_server_ts": 0,
+ "prev_events": [],
+ "room_id": self.intially_unjoined_room_id,
+ "sender": f"@creator:{self.OTHER_SERVER_NAME}",
+ "state_key": "",
+ "type": EventTypes.Create,
+ }
+ self.add_hashes_and_signatures_from_other_server(
+ create_event_source,
+ self.hs.config.server.default_room_version,
+ )
+ create_event = FrozenEventV3(
+ create_event_source,
+ self.hs.config.server.default_room_version,
+ {},
+ None,
+ )
+
+ join_event_source = {
+ "auth_events": [create_event.event_id],
+ "content": {"membership": "join"},
+ "depth": 1,
+ "origin_server_ts": 100,
+ "prev_events": [create_event.event_id],
+ "sender": self.bob,
+ "state_key": self.bob,
+ "room_id": self.intially_unjoined_room_id,
+ "type": EventTypes.Member,
+ }
+ add_hashes_and_signatures(
+ self.hs.config.server.default_room_version,
+ join_event_source,
+ self.hs.hostname,
+ self.hs.signing_key,
+ )
+ join_event = FrozenEventV3(
+ join_event_source,
+ self.hs.config.server.default_room_version,
+ {},
+ None,
+ )
+
+ mock_make_membership_event = Mock(
+ return_value=make_awaitable(
+ (
+ self.OTHER_SERVER_NAME,
+ join_event,
+ self.hs.config.server.default_room_version,
+ )
+ )
+ )
+ mock_send_join = Mock(
+ return_value=make_awaitable(
+ SendJoinResult(
+ join_event,
+ self.OTHER_SERVER_NAME,
+ state=[create_event],
+ auth_chain=[create_event],
+ partial_state=False,
+ servers_in_room=[],
+ )
+ )
+ )
+
+ with patch.object(
+ self.handler.federation_handler.federation_client,
+ "make_membership_event",
+ mock_make_membership_event,
+ ), patch.object(
+ self.handler.federation_handler.federation_client,
+ "send_join",
+ mock_send_join,
+ ), patch(
+ "synapse.event_auth._is_membership_change_allowed",
+ return_value=None,
+ ), patch(
+ "synapse.handlers.federation_event.check_state_dependent_auth_rules",
+ return_value=None,
+ ):
+ self.get_success(
+ self.handler.update_membership(
+ requester=create_requester(self.bob),
+ target=UserID.from_string(self.bob),
+ room_id=self.intially_unjoined_room_id,
+ action=Membership.JOIN,
+ remote_room_hosts=[self.OTHER_SERVER_NAME],
+ )
+ )
+
+ # Try to join as Chris. Should get denied.
+ self.get_failure(
+ self.handler.update_membership(
+ requester=create_requester(self.chris),
+ target=UserID.from_string(self.chris),
+ room_id=self.intially_unjoined_room_id,
+ action=Membership.JOIN,
+ remote_room_hosts=[self.OTHER_SERVER_NAME],
+ ),
+ LimitExceededError,
+ )
+
+ # TODO: test that remote joins to a room are rate limited.
+ # Could do this by setting the burst count to 1, then:
+ # - remote-joining a room
+ # - immediately leaving
+ # - trying to remote-join again.
+
+
+class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.client.login.register_servlets,
+ synapse.rest.client.room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.handler = hs.get_room_member_handler()
+
+ # Create three users.
+ self.alice = self.register_user("alice", "pass")
+ self.alice_token = self.login("alice", "pass")
+ self.bob = self.register_user("bob", "pass")
+ self.bob_token = self.login("bob", "pass")
+ self.chris = self.register_user("chris", "pass")
+ self.chris_token = self.login("chris", "pass")
+
+ # Create a room on this homeserver.
+ # Note that this counts as a
+ self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
+ self.intially_unjoined_room_id = "!example:otherhs"
+
+ @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
+ def test_local_users_joining_on_another_worker_contribute_to_rate_limit(
+ self,
+ ) -> None:
+ # The rate limiter has accumulated one token from Alice's join after the create
+ # event.
+ self.replicate()
+
+ # Spawn another worker and have bob join via it.
+ worker_app = self.make_worker_hs(
+ "synapse.app.generic_worker", extra_config={"worker_name": "other worker"}
+ )
+ worker_site = self._hs_to_site[worker_app]
+ channel = make_request(
+ self.reactor,
+ worker_site,
+ "POST",
+ f"/_matrix/client/v3/rooms/{self.room_id}/join",
+ access_token=self.bob_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+
+ # wait for join to arrive over replication
+ self.replicate()
+
+ # Try to join as Chris on the worker. Should get denied because Alice
+ # and Bob have both joined the room.
+ self.get_failure(
+ worker_app.get_room_member_handler().update_membership(
+ requester=create_requester(self.chris),
+ target=UserID.from_string(self.chris),
+ room_id=self.room_id,
+ action=Membership.JOIN,
+ ),
+ LimitExceededError,
+ )
+
+ # Try to join as Chris on the original worker. Should get denied because Alice
+ # and Bob have both joined the room.
+ self.get_failure(
+ self.handler.update_membership(
+ requester=create_requester(self.chris),
+ target=UserID.from_string(self.chris),
+ room_id=self.room_id,
+ action=Membership.JOIN,
+ ),
+ LimitExceededError,
+ )
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index ecc7cc6461..e3f38fbcc5 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -159,7 +159,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Blow away caches (supported room versions can only change due to a restart).
self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
- self.store._get_event_cache.clear()
+ self.get_success(self.store._get_event_cache.clear())
self.store._event_ref.clear()
# The rooms should be excluded from the sync response.
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 230dc76f72..2526136ff8 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -21,7 +21,7 @@ from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RoomTypes
from synapse.api.errors import Codes
from synapse.handlers.pagination import PaginationHandler
from synapse.rest.client import directory, events, login, room
@@ -1130,6 +1130,8 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertIn("guest_access", r)
self.assertIn("history_visibility", r)
self.assertIn("state_events", r)
+ self.assertIn("room_type", r)
+ self.assertIsNone(r["room_type"])
# Check that the correct number of total rooms was returned
self.assertEqual(channel.json_body["total_rooms"], total_rooms)
@@ -1229,7 +1231,11 @@ class RoomTestCase(unittest.HomeserverTestCase):
def test_correct_room_attributes(self) -> None:
"""Test the correct attributes for a room are returned"""
# Create a test room
- room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id = self.helper.create_room_as(
+ self.admin_user,
+ tok=self.admin_user_tok,
+ extra_content={"creation_content": {"type": RoomTypes.SPACE}},
+ )
test_alias = "#test:test"
test_room_name = "something"
@@ -1306,6 +1312,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(room_id, r["room_id"])
self.assertEqual(test_room_name, r["name"])
self.assertEqual(test_alias, r["canonical_alias"])
+ self.assertEqual(RoomTypes.SPACE, r["room_type"])
def test_room_list_sort_order(self) -> None:
"""Test room list sort ordering. alphabetical name versus number of members,
@@ -1630,7 +1637,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertIn("guest_access", channel.json_body)
self.assertIn("history_visibility", channel.json_body)
self.assertIn("state_events", channel.json_body)
-
+ self.assertIn("room_type", channel.json_body)
self.assertEqual(room_id_1, channel.json_body["room_id"])
def test_single_room_devices(self) -> None:
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index e32aaadb98..12db68d564 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1379,7 +1379,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
- self.assertEqual(201, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1434,7 +1434,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
- self.assertEqual(201, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1512,7 +1512,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123", "admin": False},
)
- self.assertEqual(201, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"])
@@ -1550,7 +1550,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# Admin user is not blocked by mau anymore
- self.assertEqual(201, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"])
@@ -1585,7 +1585,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
- self.assertEqual(201, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
@@ -1626,7 +1626,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
- self.assertEqual(201, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
@@ -1636,6 +1636,41 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(len(pushers), 0)
+ @override_config(
+ {
+ "email": {
+ "enable_notifs": True,
+ "notif_for_new_users": True,
+ "notif_from": "test@example.com",
+ },
+ "public_baseurl": "https://example.com",
+ }
+ )
+ def test_create_user_email_notif_for_new_users_with_msisdn_threepid(self) -> None:
+ """
+ Check that a new regular user is created successfully when they have a msisdn
+ threepid and email notif_for_new_users is set to True.
+ """
+ url = self.url_prefix % "@bob:test"
+
+ # Create user
+ body = {
+ "password": "abc123",
+ "threepids": [{"medium": "msisdn", "address": "1234567890"}],
+ }
+
+ channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body,
+ )
+
+ self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"])
+
def test_set_password(self) -> None:
"""
Test setting a new password for another user.
@@ -2372,7 +2407,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123"},
)
- self.assertEqual(201, channel.code, msg=channel.json_body)
+ self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 1f9b65351e..7ae926dc9c 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import os
import re
from email.parser import Parser
+from http import HTTPStatus
from typing import Any, Dict, List, Optional, Union
from unittest.mock import Mock
@@ -95,10 +95,8 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
"""
body = {"type": "m.login.password", "user": username, "password": password}
- channel = self.make_request(
- "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
- )
- self.assertEqual(channel.code, 403, channel.result)
+ channel = self.make_request("POST", "/_matrix/client/r0/login", body)
+ self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
def test_basic_password_reset(self) -> None:
"""Test basic password reset flow"""
@@ -347,7 +345,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
shorthand=False,
)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
# Now POST to the same endpoint, mimicking the same behaviour as clicking the
# password reset confirm button
@@ -362,7 +360,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
shorthand=False,
content_is_form=True,
)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
def _get_link_from_email(self) -> str:
assert self.email_attempts, "No emails have been sent"
@@ -390,7 +388,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
new_password: str,
session_id: str,
client_secret: str,
- expected_code: int = 200,
+ expected_code: int = HTTPStatus.OK,
) -> None:
channel = self.make_request(
"POST",
@@ -479,16 +477,14 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
self.assertEqual(memberships[0].room_id, room_id, memberships)
def deactivate(self, user_id: str, tok: str) -> None:
- request_data = json.dumps(
- {
- "auth": {
- "type": "m.login.password",
- "user": user_id,
- "password": "test",
- },
- "erase": False,
- }
- )
+ request_data = {
+ "auth": {
+ "type": "m.login.password",
+ "user": user_id,
+ "password": "test",
+ },
+ "erase": False,
+ }
channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
@@ -715,7 +711,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
@@ -725,7 +723,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_delete_email(self) -> None:
@@ -747,7 +745,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
{"medium": "email", "address": self.email},
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# Get user
channel = self.make_request(
@@ -756,7 +754,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_delete_email_if_disabled(self) -> None:
@@ -781,7 +779,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
@@ -791,7 +791,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
@@ -817,7 +817,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
@@ -827,7 +829,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_no_valid_token(self) -> None:
@@ -852,7 +854,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
@@ -862,7 +866,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
@override_config({"next_link_domain_whitelist": None})
@@ -872,7 +876,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="https://example.com/a/good/site",
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
@override_config({"next_link_domain_whitelist": None})
@@ -884,7 +888,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="some-protocol://abcdefghijklmopqrstuvwxyz",
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
@override_config({"next_link_domain_whitelist": None})
@@ -895,7 +899,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="file:///host/path",
- expect_code=400,
+ expect_code=HTTPStatus.BAD_REQUEST,
)
@override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
@@ -907,28 +911,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link=None,
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
self._request_token(
"something@example.com",
"some_secret",
next_link="https://example.com/some/good/page",
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
self._request_token(
"something@example.com",
"some_secret",
next_link="https://example.org/some/also/good/page",
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
self._request_token(
"something@example.com",
"some_secret",
next_link="https://bad.example.org/some/bad/page",
- expect_code=400,
+ expect_code=HTTPStatus.BAD_REQUEST,
)
@override_config({"next_link_domain_whitelist": []})
@@ -940,7 +944,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="https://example.com/a/page",
- expect_code=400,
+ expect_code=HTTPStatus.BAD_REQUEST,
)
def _request_token(
@@ -948,7 +952,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
email: str,
client_secret: str,
next_link: Optional[str] = None,
- expect_code: int = 200,
+ expect_code: int = HTTPStatus.OK,
) -> Optional[str]:
"""Request a validation token to add an email address to a user's account
@@ -993,7 +997,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
b"account/3pid/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
self.assertEqual(expected_errcode, channel.json_body["errcode"])
self.assertEqual(expected_error, channel.json_body["error"])
@@ -1002,7 +1008,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
path = link.replace("https://example.com", "")
channel = self.make_request("GET", path, shorthand=False)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
def _get_link_from_email(self) -> str:
assert self.email_attempts, "No emails have been sent"
@@ -1052,7 +1058,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# Get user
channel = self.make_request(
@@ -1061,7 +1067,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
@@ -1092,7 +1098,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
"""Tests that not providing any MXID raises an error."""
self._test_status(
users=None,
- expected_status_code=400,
+ expected_status_code=HTTPStatus.BAD_REQUEST,
expected_errcode=Codes.MISSING_PARAM,
)
@@ -1100,7 +1106,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
"""Tests that providing an invalid MXID raises an error."""
self._test_status(
users=["bad:test"],
- expected_status_code=400,
+ expected_status_code=HTTPStatus.BAD_REQUEST,
expected_errcode=Codes.INVALID_PARAM,
)
@@ -1286,7 +1292,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
def _test_status(
self,
users: Optional[List[str]],
- expected_status_code: int = 200,
+ expected_status_code: int = HTTPStatus.OK,
expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
expected_failures: Optional[List[str]] = None,
expected_errcode: Optional[str] = None,
diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py
index 16e7ef41bc..7a88aa2cda 100644
--- a/tests/rest/client/test_directory.py
+++ b/tests/rest/client/test_directory.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
@@ -97,8 +96,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
# We use deliberately a localpart under the length threshold so
# that we can make sure that the check is done on the whole alias.
- data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
- request_data = json.dumps(data)
+ request_data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
@@ -110,8 +108,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
# Check with an alias of allowed length. There should already be
# a test that ensures it works in test_register.py, but let's be
# as cautious as possible here.
- data = {"room_alias_name": random_string(5)}
- request_data = json.dumps(data)
+ request_data = {"room_alias_name": random_string(5)}
channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
@@ -144,8 +141,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
# Add an alias for the room, as the appservice
alias = RoomAlias(f"asns-{random_string(5)}", self.hs.hostname).to_string()
- data = {"room_id": self.room_id}
- request_data = json.dumps(data)
+ request_data = {"room_id": self.room_id}
channel = self.make_request(
"PUT",
@@ -193,8 +189,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.hs.hostname,
)
- data = {"aliases": [self.random_alias(alias_length)]}
- request_data = json.dumps(data)
+ request_data = {"aliases": [self.random_alias(alias_length)]}
channel = self.make_request(
"PUT", url, request_data, access_token=self.user_tok
@@ -206,8 +201,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
) -> str:
alias = self.random_alias(alias_length)
url = "/_matrix/client/r0/directory/room/%s" % alias
- data = {"room_id": self.room_id}
- request_data = json.dumps(data)
+ request_data = {"room_id": self.room_id}
channel = self.make_request(
"PUT", url, request_data, access_token=self.user_tok
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index 299b9d21e2..dc17c9d113 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
@@ -51,12 +50,11 @@ class IdentityTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
room_id = channel.json_body["room_id"]
- params = {
+ request_data = {
"id_server": "testis",
"medium": "email",
"address": "test@example.com",
}
- request_data = json.dumps(params)
request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
channel = self.make_request(
b"POST", request_url, request_data, access_token=tok
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index f6efa5fe37..a2958f6959 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import time
import urllib.parse
+from http import HTTPStatus
from typing import Any, Dict, List, Optional
from unittest.mock import Mock
from urllib.parse import urlencode
@@ -261,20 +261,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
access_token = channel.json_body["access_token"]
device_id = channel.json_body["device_id"]
# we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, 401, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
@@ -288,7 +288,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# more requests with the expired token should still return a soft-logout
self.reactor.advance(3600)
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, 401, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
@@ -296,7 +296,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self._delete_device(access_token_2, "kermit", "monkey", device_id)
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, 401, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], False)
@@ -307,7 +307,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
b"DELETE", "devices/" + device_id, access_token=access_token
)
- self.assertEqual(channel.code, 401, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
# check it's a UI-Auth fail
self.assertEqual(
set(channel.json_body.keys()),
@@ -330,7 +330,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
access_token=access_token,
content={"auth": auth},
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
@override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
@@ -341,14 +341,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, 401, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
@@ -367,14 +367,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, 401, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
@@ -399,7 +399,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
"/_matrix/client/v3/login",
- json.dumps(body).encode("utf8"),
+ body,
custom_headers=None,
)
@@ -466,7 +466,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def test_get_login_flows(self) -> None:
"""GET /login should return password and SSO flows"""
channel = self.make_request("GET", "/_matrix/client/r0/login")
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
expected_flow_types = [
"m.login.cas",
@@ -494,14 +494,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"""/login/sso/redirect should redirect to an identity picker"""
# first hit the redirect url, which should redirect to our idp picker
channel = self._make_sso_redirect_request(None)
- self.assertEqual(channel.code, 302, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
uri = location_headers[0]
# hitting that picker should give us some HTML
channel = self.make_request("GET", uri)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
# parse the form to check it has fields assumed elsewhere in this class
html = channel.result["body"].decode("utf-8")
@@ -530,7 +530,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ "&idp=cas",
shorthand=False,
)
- self.assertEqual(channel.code, 302, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
cas_uri = location_headers[0]
@@ -555,7 +555,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=saml",
)
- self.assertEqual(channel.code, 302, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
saml_uri = location_headers[0]
@@ -579,7 +579,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=oidc",
)
- self.assertEqual(channel.code, 302, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
oidc_uri = location_headers[0]
@@ -606,7 +606,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
# that should serve a confirmation page
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
content_type_headers = channel.headers.getRawHeaders("Content-Type")
assert content_type_headers
self.assertTrue(content_type_headers[-1].startswith("text/html"))
@@ -634,7 +634,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"/login",
content={"type": "m.login.token", "token": login_token},
)
- self.assertEqual(chan.code, 200, chan.result)
+ self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
self.assertEqual(chan.json_body["user_id"], "@user1:test")
def test_multi_sso_redirect_to_unknown(self) -> None:
@@ -643,18 +643,18 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"GET",
"/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
)
- self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
def test_client_idp_redirect_to_unknown(self) -> None:
"""If the client tries to pick an unknown IdP, return a 404"""
channel = self._make_sso_redirect_request("xxx")
- self.assertEqual(channel.code, 404, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
def test_client_idp_redirect_to_oidc(self) -> None:
"""If the client pick a known IdP, redirect to it"""
channel = self._make_sso_redirect_request("oidc")
- self.assertEqual(channel.code, 302, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
oidc_uri = location_headers[0]
@@ -765,7 +765,7 @@ class CASTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", cas_ticket_url)
# Test that the response is HTML.
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
content_type_header_value = ""
for header in channel.result.get("headers", []):
if header[0] == b"Content-Type":
@@ -1246,7 +1246,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
)
# that should redirect to the username picker
- self.assertEqual(channel.code, 302, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
picker_url = location_headers[0]
@@ -1290,7 +1290,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
("Content-Length", str(len(content))),
],
)
- self.assertEqual(chan.code, 302, chan.result)
+ self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
assert location_headers
@@ -1300,7 +1300,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
path=location_headers[0],
custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
)
- self.assertEqual(chan.code, 302, chan.result)
+ self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
assert location_headers
@@ -1325,5 +1325,5 @@ class UsernamePickerTestCase(HomeserverTestCase):
"/login",
content={"type": "m.login.token", "token": login_token},
)
- self.assertEqual(chan.code, 200, chan.result)
+ self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
self.assertEqual(chan.json_body["user_id"], "@bobby:test")
diff --git a/tests/rest/client/test_password_policy.py b/tests/rest/client/test_password_policy.py
index 3a74d2e96c..e19d21d6ee 100644
--- a/tests/rest/client/test_password_policy.py
+++ b/tests/rest/client/test_password_policy.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
@@ -89,7 +88,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_too_short(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": "shorty"})
+ request_data = {"username": "kermit", "password": "shorty"}
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -100,7 +99,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_no_digit(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
+ request_data = {"username": "kermit", "password": "longerpassword"}
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -111,7 +110,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_no_symbol(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
+ request_data = {"username": "kermit", "password": "l0ngerpassword"}
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -122,7 +121,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_no_uppercase(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
+ request_data = {"username": "kermit", "password": "l0ngerpassword!"}
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -133,7 +132,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_no_lowercase(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
+ request_data = {"username": "kermit", "password": "L0NGERPASSWORD!"}
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -144,7 +143,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_compliant(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
+ request_data = {"username": "kermit", "password": "L0ngerpassword!"}
channel = self.make_request("POST", self.register_url, request_data)
# Getting a 401 here means the password has passed validation and the server has
@@ -161,16 +160,14 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
user_id = self.register_user("kermit", compliant_password)
tok = self.login("kermit", compliant_password)
- request_data = json.dumps(
- {
- "new_password": not_compliant_password,
- "auth": {
- "password": compliant_password,
- "type": LoginType.PASSWORD,
- "user": user_id,
- },
- }
- )
+ request_data = {
+ "new_password": not_compliant_password,
+ "auth": {
+ "password": compliant_password,
+ "type": LoginType.PASSWORD,
+ "user": user_id,
+ },
+ }
channel = self.make_request(
"POST",
"/_matrix/client/r0/account/password",
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index afb08b2736..071b488cc0 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
-import json
import os
from typing import Any, Dict, List, Tuple
@@ -62,9 +61,10 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
self.hs.get_datastores().main.services_cache.append(appservice)
- request_data = json.dumps(
- {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
- )
+ request_data = {
+ "username": "as_user_kermit",
+ "type": APP_SERVICE_REGISTRATION_TYPE,
+ }
channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
@@ -85,7 +85,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
self.hs.get_datastores().main.services_cache.append(appservice)
- request_data = json.dumps({"username": "as_user_kermit"})
+ request_data = {"username": "as_user_kermit"}
channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
@@ -95,9 +95,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_appservice_registration_invalid(self) -> None:
self.appservice = None # no application service exists
- request_data = json.dumps(
- {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
- )
+ request_data = {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
@@ -105,14 +103,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"401", channel.result)
def test_POST_bad_password(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": 666})
+ request_data = {"username": "kermit", "password": 666}
channel = self.make_request(b"POST", self.url, request_data)
self.assertEqual(channel.result["code"], b"400", channel.result)
self.assertEqual(channel.json_body["error"], "Invalid password")
def test_POST_bad_username(self) -> None:
- request_data = json.dumps({"username": 777, "password": "monkey"})
+ request_data = {"username": 777, "password": "monkey"}
channel = self.make_request(b"POST", self.url, request_data)
self.assertEqual(channel.result["code"], b"400", channel.result)
@@ -121,13 +119,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_user_valid(self) -> None:
user_id = "@kermit:test"
device_id = "frogfone"
- params = {
+ request_data = {
"username": "kermit",
"password": "monkey",
"device_id": device_id,
"auth": {"type": LoginType.DUMMY},
}
- request_data = json.dumps(params)
channel = self.make_request(b"POST", self.url, request_data)
det_data = {
@@ -140,7 +137,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
@override_config({"enable_registration": False})
def test_POST_disabled_registration(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": "monkey"})
+ request_data = {"username": "kermit", "password": "monkey"}
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
channel = self.make_request(b"POST", self.url, request_data)
@@ -188,13 +185,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting(self) -> None:
for i in range(0, 6):
- params = {
+ request_data = {
"username": "kermit" + str(i),
"password": "monkey",
"device_id": "frogfone",
"auth": {"type": LoginType.DUMMY},
}
- request_data = json.dumps(params)
channel = self.make_request(b"POST", self.url, request_data)
if i == 5:
@@ -234,7 +230,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
}
# Request without auth to get flows and session
- channel = self.make_request(b"POST", self.url, json.dumps(params))
+ channel = self.make_request(b"POST", self.url, params)
self.assertEqual(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
# Synapse adds a dummy stage to differentiate flows where otherwise one
@@ -251,8 +247,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session,
}
- request_data = json.dumps(params)
- channel = self.make_request(b"POST", self.url, request_data)
+ channel = self.make_request(b"POST", self.url, params)
self.assertEqual(channel.result["code"], b"401", channel.result)
completed = channel.json_body["completed"]
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
@@ -262,8 +257,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"type": LoginType.DUMMY,
"session": session,
}
- request_data = json.dumps(params)
- channel = self.make_request(b"POST", self.url, request_data)
+ channel = self.make_request(b"POST", self.url, params)
det_data = {
"user_id": f"@{username}:{self.hs.hostname}",
"home_server": self.hs.hostname,
@@ -290,7 +284,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"password": "monkey",
}
# Request without auth to get session
- channel = self.make_request(b"POST", self.url, json.dumps(params))
+ channel = self.make_request(b"POST", self.url, params)
session = channel.json_body["session"]
# Test with token param missing (invalid)
@@ -298,21 +292,21 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"type": LoginType.REGISTRATION_TOKEN,
"session": session,
}
- channel = self.make_request(b"POST", self.url, json.dumps(params))
+ channel = self.make_request(b"POST", self.url, params)
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
self.assertEqual(channel.json_body["completed"], [])
# Test with non-string (invalid)
params["auth"]["token"] = 1234
- channel = self.make_request(b"POST", self.url, json.dumps(params))
+ channel = self.make_request(b"POST", self.url, params)
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
self.assertEqual(channel.json_body["completed"], [])
# Test with unknown token (invalid)
params["auth"]["token"] = "1234"
- channel = self.make_request(b"POST", self.url, json.dumps(params))
+ channel = self.make_request(b"POST", self.url, params)
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
@@ -337,9 +331,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
params1: JsonDict = {"username": "bert", "password": "monkey"}
params2: JsonDict = {"username": "ernie", "password": "monkey"}
# Do 2 requests without auth to get two session IDs
- channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+ channel1 = self.make_request(b"POST", self.url, params1)
session1 = channel1.json_body["session"]
- channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+ channel2 = self.make_request(b"POST", self.url, params2)
session2 = channel2.json_body["session"]
# Use token with session1 and check `pending` is 1
@@ -348,9 +342,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session1,
}
- self.make_request(b"POST", self.url, json.dumps(params1))
+ self.make_request(b"POST", self.url, params1)
# Repeat request to make sure pending isn't increased again
- self.make_request(b"POST", self.url, json.dumps(params1))
+ self.make_request(b"POST", self.url, params1)
pending = self.get_success(
store.db_pool.simple_select_one_onecol(
"registration_tokens",
@@ -366,14 +360,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session2,
}
- channel = self.make_request(b"POST", self.url, json.dumps(params2))
+ channel = self.make_request(b"POST", self.url, params2)
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
# Complete registration with session1
params1["auth"]["type"] = LoginType.DUMMY
- self.make_request(b"POST", self.url, json.dumps(params1))
+ self.make_request(b"POST", self.url, params1)
# Check pending=0 and completed=1
res = self.get_success(
store.db_pool.simple_select_one(
@@ -386,7 +380,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(res["completed"], 1)
# Check auth still fails when using token with session2
- channel = self.make_request(b"POST", self.url, json.dumps(params2))
+ channel = self.make_request(b"POST", self.url, params2)
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
@@ -411,7 +405,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
params: JsonDict = {"username": "kermit", "password": "monkey"}
# Request without auth to get session
- channel = self.make_request(b"POST", self.url, json.dumps(params))
+ channel = self.make_request(b"POST", self.url, params)
session = channel.json_body["session"]
# Check authentication fails with expired token
@@ -420,7 +414,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session,
}
- channel = self.make_request(b"POST", self.url, json.dumps(params))
+ channel = self.make_request(b"POST", self.url, params)
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
@@ -435,7 +429,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
# Check authentication succeeds
- channel = self.make_request(b"POST", self.url, json.dumps(params))
+ channel = self.make_request(b"POST", self.url, params)
completed = channel.json_body["completed"]
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
@@ -460,9 +454,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
# Do 2 requests without auth to get two session IDs
params1: JsonDict = {"username": "bert", "password": "monkey"}
params2: JsonDict = {"username": "ernie", "password": "monkey"}
- channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+ channel1 = self.make_request(b"POST", self.url, params1)
session1 = channel1.json_body["session"]
- channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+ channel2 = self.make_request(b"POST", self.url, params2)
session2 = channel2.json_body["session"]
# Use token with both sessions
@@ -471,18 +465,18 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session1,
}
- self.make_request(b"POST", self.url, json.dumps(params1))
+ self.make_request(b"POST", self.url, params1)
params2["auth"] = {
"type": LoginType.REGISTRATION_TOKEN,
"token": token,
"session": session2,
}
- self.make_request(b"POST", self.url, json.dumps(params2))
+ self.make_request(b"POST", self.url, params2)
# Complete registration with session1
params1["auth"]["type"] = LoginType.DUMMY
- self.make_request(b"POST", self.url, json.dumps(params1))
+ self.make_request(b"POST", self.url, params1)
# Check `result` of registration token stage for session1 is `True`
result1 = self.get_success(
@@ -550,7 +544,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
# Do request without auth to get a session ID
params: JsonDict = {"username": "kermit", "password": "monkey"}
- channel = self.make_request(b"POST", self.url, json.dumps(params))
+ channel = self.make_request(b"POST", self.url, params)
session = channel.json_body["session"]
# Use token
@@ -559,7 +553,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session,
}
- self.make_request(b"POST", self.url, json.dumps(params))
+ self.make_request(b"POST", self.url, params)
# Delete token
self.get_success(
@@ -592,9 +586,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"require_at_registration": True,
},
"account_threepid_delegates": {
- "email": "https://id_server",
"msisdn": "https://id_server",
},
+ "email": {"notif_from": "Synapse <synapse@example.com>"},
}
)
def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None:
@@ -827,8 +821,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
admin_tok = self.login("admin", "adminpassword")
url = "/_synapse/admin/v1/account_validity/validity"
- params = {"user_id": user_id}
- request_data = json.dumps(params)
+ request_data = {"user_id": user_id}
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEqual(channel.result["code"], b"200", channel.result)
@@ -845,12 +838,11 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
admin_tok = self.login("admin", "adminpassword")
url = "/_synapse/admin/v1/account_validity/validity"
- params = {
+ request_data = {
"user_id": user_id,
"expiration_ts": 0,
"enable_renewal_emails": False,
}
- request_data = json.dumps(params)
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEqual(channel.result["code"], b"200", channel.result)
@@ -870,12 +862,11 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
admin_tok = self.login("admin", "adminpassword")
url = "/_synapse/admin/v1/account_validity/validity"
- params = {
+ request_data = {
"user_id": user_id,
"expiration_ts": 0,
"enable_renewal_emails": False,
}
- request_data = json.dumps(params)
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEqual(channel.result["code"], b"200", channel.result)
@@ -1041,16 +1032,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
(user_id, tok) = self.create_user()
- request_data = json.dumps(
- {
- "auth": {
- "type": "m.login.password",
- "user": user_id,
- "password": "monkey",
- },
- "erase": False,
- }
- )
+ request_data = {
+ "auth": {
+ "type": "m.login.password",
+ "user": user_id,
+ "password": "monkey",
+ },
+ "erase": False,
+ }
channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py
index 20a259fc43..ad0d0209f7 100644
--- a/tests/rest/client/test_report_event.py
+++ b/tests/rest/client/test_report_event.py
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
-
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -77,10 +75,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
def _assert_status(self, response_status: int, data: JsonDict) -> None:
channel = self.make_request(
- "POST",
- self.report_path,
- json.dumps(data),
- access_token=self.other_user_tok,
+ "POST", self.report_path, data, access_token=self.other_user_tok
)
self.assertEqual(
response_status, int(channel.result["code"]), msg=channel.result["body"]
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index df7ffbe545..c45cb32090 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -18,6 +18,7 @@
"""Tests REST events for /rooms paths."""
import json
+from http import HTTPStatus
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from unittest.mock import Mock, call
from urllib import parse as urlparse
@@ -104,7 +105,7 @@ class RoomPermissionsTestCase(RoomBase):
channel = self.make_request(
"PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}'
)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
# set topic for public room
channel = self.make_request(
@@ -112,7 +113,7 @@ class RoomPermissionsTestCase(RoomBase):
("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"),
b'{"topic":"Public Room Topic"}',
)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
# auth as user_id now
self.helper.auth_user_id = self.user_id
@@ -134,28 +135,28 @@ class RoomPermissionsTestCase(RoomBase):
"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
msg_content,
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# send message in created room not joined (no state), expect 403
channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# send message in created room and invited, expect 403
self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# send message in created room and joined, expect 200
self.helper.join(room=self.created_rmid, user=self.user_id)
channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# send message in created room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id)
channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_topic_perms(self) -> None:
topic_content = b'{"topic":"My Topic Name"}'
@@ -165,28 +166,28 @@ class RoomPermissionsTestCase(RoomBase):
channel = self.make_request(
"PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
channel = self.make_request(
"GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room not joined, expect 403
channel = self.make_request("PUT", topic_path, topic_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", topic_path)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# set topic in created PRIVATE room and invited, expect 403
self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
channel = self.make_request("PUT", topic_path, topic_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# get topic in created PRIVATE room and invited, expect 403
channel = self.make_request("GET", topic_path)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room and joined, expect 200
self.helper.join(room=self.created_rmid, user=self.user_id)
@@ -194,25 +195,25 @@ class RoomPermissionsTestCase(RoomBase):
# Only room ops can set topic by default
self.helper.auth_user_id = self.rmcreator_id
channel = self.make_request("PUT", topic_path, topic_content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.helper.auth_user_id = self.user_id
channel = self.make_request("GET", topic_path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body)
# set/get topic in created PRIVATE room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id)
channel = self.make_request("PUT", topic_path, topic_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", topic_path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# get topic in PUBLIC room, not joined, expect 403
channel = self.make_request(
"GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# set topic in PUBLIC room, not joined, expect 403
channel = self.make_request(
@@ -220,7 +221,7 @@ class RoomPermissionsTestCase(RoomBase):
"/rooms/%s/state/m.room.topic" % self.created_public_rmid,
topic_content,
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def _test_get_membership(
self, room: str, members: Iterable = frozenset(), expect_code: int = 200
@@ -309,14 +310,14 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=self.rmcreator_id,
membership=Membership.JOIN,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
self.helper.change_membership(
room=room,
src=self.user_id,
targ=self.rmcreator_id,
membership=Membership.LEAVE,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
def test_joined_permissions(self) -> None:
@@ -342,7 +343,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=other,
membership=Membership.JOIN,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
# set left of other, expect 403
@@ -351,7 +352,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=other,
membership=Membership.LEAVE,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
# set left of self, expect 200
@@ -371,7 +372,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=usr,
membership=Membership.INVITE,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
self.helper.change_membership(
@@ -379,7 +380,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=usr,
membership=Membership.JOIN,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
# It is always valid to LEAVE if you've already left (currently.)
@@ -388,7 +389,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=self.rmcreator_id,
membership=Membership.LEAVE,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
# tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember
@@ -405,7 +406,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=other,
membership=Membership.BAN,
- expect_code=403, # expect failure
+ expect_code=HTTPStatus.FORBIDDEN, # expect failure
expect_errcode=Codes.FORBIDDEN,
)
@@ -415,7 +416,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.rmcreator_id,
targ=other,
membership=Membership.BAN,
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
# from ban to invite: Must never happen.
@@ -424,7 +425,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.rmcreator_id,
targ=other,
membership=Membership.INVITE,
- expect_code=403, # expect failure
+ expect_code=HTTPStatus.FORBIDDEN, # expect failure
expect_errcode=Codes.BAD_STATE,
)
@@ -434,7 +435,7 @@ class RoomPermissionsTestCase(RoomBase):
src=other,
targ=other,
membership=Membership.JOIN,
- expect_code=403, # expect failure
+ expect_code=HTTPStatus.FORBIDDEN, # expect failure
expect_errcode=Codes.BAD_STATE,
)
@@ -444,7 +445,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.rmcreator_id,
targ=other,
membership=Membership.BAN,
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
# from ban to knock: Must never happen.
@@ -453,7 +454,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.rmcreator_id,
targ=other,
membership=Membership.KNOCK,
- expect_code=403, # expect failure
+ expect_code=HTTPStatus.FORBIDDEN, # expect failure
expect_errcode=Codes.BAD_STATE,
)
@@ -463,7 +464,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=other,
membership=Membership.LEAVE,
- expect_code=403, # expect failure
+ expect_code=HTTPStatus.FORBIDDEN, # expect failure
expect_errcode=Codes.FORBIDDEN,
)
@@ -473,7 +474,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.rmcreator_id,
targ=other,
membership=Membership.LEAVE,
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
@@ -493,7 +494,7 @@ class RoomStateTestCase(RoomBase):
"/rooms/%s/state" % room_id,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertCountEqual(
[state_event["type"] for state_event in channel.json_body],
{
@@ -516,7 +517,7 @@ class RoomStateTestCase(RoomBase):
"/rooms/%s/state/m.room.member/%s" % (room_id, self.user_id),
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual(channel.json_body, {"membership": "join"})
@@ -530,16 +531,16 @@ class RoomsMemberListTestCase(RoomBase):
def test_get_member_list(self) -> None:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
def test_get_member_list_no_room(self) -> None:
channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission(self) -> None:
room_id = self.helper.create_room_as("@some_other_guy:red")
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission_with_at_token(self) -> None:
"""
@@ -550,7 +551,7 @@ class RoomsMemberListTestCase(RoomBase):
# first sync to get an at token
channel = self.make_request("GET", "/sync")
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
sync_token = channel.json_body["next_batch"]
# check that permission is denied for @sid1:red to get the
@@ -559,7 +560,7 @@ class RoomsMemberListTestCase(RoomBase):
"GET",
f"/rooms/{room_id}/members?at={sync_token}",
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission_former_member(self) -> None:
"""
@@ -572,14 +573,14 @@ class RoomsMemberListTestCase(RoomBase):
# check that the user can see the member list to start with
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# ban the user
self.helper.change_membership(room_id, "@alice:red", self.user_id, "ban")
# check the user can no longer see the member list
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission_former_member_with_at_token(self) -> None:
"""
@@ -593,14 +594,14 @@ class RoomsMemberListTestCase(RoomBase):
# sync to get an at token
channel = self.make_request("GET", "/sync")
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
sync_token = channel.json_body["next_batch"]
# check that the user can see the member list to start with
channel = self.make_request(
"GET", "/rooms/%s/members?at=%s" % (room_id, sync_token)
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# ban the user (Note: the user is actually allowed to see this event and
# state so that they know they're banned!)
@@ -612,14 +613,14 @@ class RoomsMemberListTestCase(RoomBase):
# now, with the original user, sync again to get a new at token
channel = self.make_request("GET", "/sync")
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
sync_token = channel.json_body["next_batch"]
# check the user can no longer see the updated member list
channel = self.make_request(
"GET", "/rooms/%s/members?at=%s" % (room_id, sync_token)
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_get_member_list_mixed_memberships(self) -> None:
room_creator = "@some_other_guy:red"
@@ -628,17 +629,17 @@ class RoomsMemberListTestCase(RoomBase):
self.helper.invite(room=room_id, src=room_creator, targ=self.user_id)
# can't see list if you're just invited.
channel = self.make_request("GET", room_path)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
self.helper.join(room=room_id, user=self.user_id)
# can see list now joined
channel = self.make_request("GET", room_path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.helper.leave(room=room_id, user=self.user_id)
# can see old list once left
channel = self.make_request("GET", room_path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
def test_get_member_list_cancellation(self) -> None:
"""Test cancellation of a `/rooms/$room_id/members` request."""
@@ -651,7 +652,7 @@ class RoomsMemberListTestCase(RoomBase):
"/rooms/%s/members" % room_id,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual(len(channel.json_body["chunk"]), 1)
self.assertLessEqual(
{
@@ -671,7 +672,7 @@ class RoomsMemberListTestCase(RoomBase):
# first sync to get an at token
channel = self.make_request("GET", "/sync")
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
sync_token = channel.json_body["next_batch"]
channel = make_request_with_cancellation_test(
@@ -682,7 +683,7 @@ class RoomsMemberListTestCase(RoomBase):
"/rooms/%s/members?at=%s" % (room_id, sync_token),
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual(len(channel.json_body["chunk"]), 1)
self.assertLessEqual(
{
@@ -706,10 +707,10 @@ class RoomsCreateTestCase(RoomBase):
# POST with no config keys, expect new room id
channel = self.make_request("POST", "/createRoom", "{}")
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(37, channel.resource_usage.db_txn_count)
+ self.assertEqual(44, channel.resource_usage.db_txn_count)
def test_post_room_initial_state(self) -> None:
# POST with initial_state config key, expect new room id
@@ -719,21 +720,21 @@ class RoomsCreateTestCase(RoomBase):
b'{"initial_state":[{"type": "m.bridge", "content": {}}]}',
)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(41, channel.resource_usage.db_txn_count)
+ self.assertEqual(50, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id
channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}')
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_custom_key(self) -> None:
# POST with custom config keys, expect new room id
channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}')
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_known_and_unknown_keys(self) -> None:
@@ -741,16 +742,16 @@ class RoomsCreateTestCase(RoomBase):
channel = self.make_request(
"POST", "/createRoom", b'{"visibility":"private","custom":"things"}'
)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_invalid_content(self) -> None:
# POST with invalid content / paths, expect 400
channel = self.make_request("POST", "/createRoom", b'{"visibili')
- self.assertEqual(400, channel.code)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code)
channel = self.make_request("POST", "/createRoom", b'["hello"]')
- self.assertEqual(400, channel.code)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code)
def test_post_room_invitees_invalid_mxid(self) -> None:
# POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088
@@ -758,7 +759,7 @@ class RoomsCreateTestCase(RoomBase):
channel = self.make_request(
"POST", "/createRoom", b'{"invite":["@alice:example.com "]}'
)
- self.assertEqual(400, channel.code)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code)
@unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}})
def test_post_room_invitees_ratelimit(self) -> None:
@@ -769,20 +770,18 @@ class RoomsCreateTestCase(RoomBase):
# Build the request's content. We use local MXIDs because invites over federation
# are more difficult to mock.
- content = json.dumps(
- {
- "invite": [
- "@alice1:red",
- "@alice2:red",
- "@alice3:red",
- "@alice4:red",
- ]
- }
- ).encode("utf8")
+ content = {
+ "invite": [
+ "@alice1:red",
+ "@alice2:red",
+ "@alice3:red",
+ "@alice4:red",
+ ]
+ }
# Test that the invites are correctly ratelimited.
channel = self.make_request("POST", "/createRoom", content)
- self.assertEqual(400, channel.code)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code)
self.assertEqual(
"Cannot invite so many users at once",
channel.json_body["error"],
@@ -795,7 +794,7 @@ class RoomsCreateTestCase(RoomBase):
# Test that the invites aren't ratelimited anymore.
channel = self.make_request("POST", "/createRoom", content)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
def test_spam_checker_may_join_room_deprecated(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly bypassed
@@ -819,7 +818,7 @@ class RoomsCreateTestCase(RoomBase):
"/createRoom",
{},
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
self.assertEqual(join_mock.call_count, 0)
@@ -845,7 +844,7 @@ class RoomsCreateTestCase(RoomBase):
"/createRoom",
{},
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
self.assertEqual(join_mock.call_count, 0)
@@ -865,7 +864,7 @@ class RoomsCreateTestCase(RoomBase):
"/createRoom",
{},
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
self.assertEqual(join_mock.call_count, 0)
@@ -882,54 +881,68 @@ class RoomTopicTestCase(RoomBase):
def test_invalid_puts(self) -> None:
# missing keys or invalid json
channel = self.make_request("PUT", self.path, "{}")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", self.path, '{"_name":"bo"}')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", self.path, '{"nao')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request(
"PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]'
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", self.path, "text only")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", self.path, "")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
# valid key, wrong type
content = '{"topic":["Topic name"]}'
channel = self.make_request("PUT", self.path, content)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
def test_rooms_topic(self) -> None:
# nothing should be there
channel = self.make_request("GET", self.path)
- self.assertEqual(404, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.result["body"])
# valid put
content = '{"topic":"Topic name"}'
channel = self.make_request("PUT", self.path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# valid get
channel = self.make_request("GET", self.path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
def test_rooms_topic_with_extra_keys(self) -> None:
# valid put with extra keys
content = '{"topic":"Seasons","subtopic":"Summer"}'
channel = self.make_request("PUT", self.path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# valid get
channel = self.make_request("GET", self.path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
@@ -945,22 +958,34 @@ class RoomMemberStateTestCase(RoomBase):
path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
# missing keys or invalid json
channel = self.make_request("PUT", path, "{}")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, '{"_name":"bo"}')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, '{"nao')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, "text only")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, "")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
# valid keys, wrong types
content = '{"membership":["%s","%s","%s"]}' % (
@@ -969,7 +994,9 @@ class RoomMemberStateTestCase(RoomBase):
Membership.LEAVE,
)
channel = self.make_request("PUT", path, content.encode("ascii"))
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
def test_rooms_members_self(self) -> None:
path = "/rooms/%s/state/m.room.member/%s" % (
@@ -980,10 +1007,10 @@ class RoomMemberStateTestCase(RoomBase):
# valid join message (NOOP since we made the room)
content = '{"membership":"%s"}' % Membership.JOIN
channel = self.make_request("PUT", path, content.encode("ascii"))
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", path, content=b"")
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
expected_response = {"membership": Membership.JOIN}
self.assertEqual(expected_response, channel.json_body)
@@ -998,10 +1025,10 @@ class RoomMemberStateTestCase(RoomBase):
# valid invite message
content = '{"membership":"%s"}' % Membership.INVITE
channel = self.make_request("PUT", path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", path, content=b"")
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual(json.loads(content), channel.json_body)
def test_rooms_members_other_custom_keys(self) -> None:
@@ -1017,10 +1044,10 @@ class RoomMemberStateTestCase(RoomBase):
"Join us!",
)
channel = self.make_request("PUT", path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", path, content=b"")
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual(json.loads(content), channel.json_body)
@@ -1137,7 +1164,9 @@ class RoomJoinTestCase(RoomBase):
# Now make the callback deny all room joins, and check that a join actually fails.
return_value = False
- self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2)
+ self.helper.join(
+ self.room3, self.user2, expect_code=HTTPStatus.FORBIDDEN, tok=self.tok2
+ )
def test_spam_checker_may_join_room(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly called
@@ -1205,7 +1234,7 @@ class RoomJoinTestCase(RoomBase):
self.helper.join(
self.room3,
self.user2,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
expect_errcode=return_value,
tok=self.tok2,
)
@@ -1216,7 +1245,7 @@ class RoomJoinTestCase(RoomBase):
self.helper.join(
self.room3,
self.user2,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
expect_errcode=return_value[0],
tok=self.tok2,
expect_additional_fields=return_value[1],
@@ -1270,7 +1299,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
# Update the display name for the user.
path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
channel = self.make_request("PUT", path, {"displayname": "John Doe"})
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# Check that all the rooms have been sent a profile update into.
for room_id in room_ids:
@@ -1335,71 +1364,93 @@ class RoomMessagesTestCase(RoomBase):
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
# missing keys or invalid json
channel = self.make_request("PUT", path, b"{}")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b'{"_name":"bo"}')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b'{"nao')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b"text only")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b"")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
def test_rooms_messages_sent(self) -> None:
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
content = b'{"body":"test","msgtype":{"type":"a"}}'
channel = self.make_request("PUT", path, content)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
# custom message types
content = b'{"body":"test","msgtype":"test.custom.text"}'
channel = self.make_request("PUT", path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# m.text message type
path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id))
content = b'{"body":"test2","msgtype":"m.text"}'
channel = self.make_request("PUT", path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
@parameterized.expand(
[
# Allow
param(
- name="NOT_SPAM", value="NOT_SPAM", expected_code=200, expected_fields={}
+ name="NOT_SPAM",
+ value="NOT_SPAM",
+ expected_code=HTTPStatus.OK,
+ expected_fields={},
+ ),
+ param(
+ name="False",
+ value=False,
+ expected_code=HTTPStatus.OK,
+ expected_fields={},
),
- param(name="False", value=False, expected_code=200, expected_fields={}),
# Block
param(
name="scalene string",
value="ANY OTHER STRING",
- expected_code=403,
+ expected_code=HTTPStatus.FORBIDDEN,
expected_fields={"errcode": "M_FORBIDDEN"},
),
param(
name="True",
value=True,
- expected_code=403,
+ expected_code=HTTPStatus.FORBIDDEN,
expected_fields={"errcode": "M_FORBIDDEN"},
),
param(
name="Code",
value=Codes.LIMIT_EXCEEDED,
- expected_code=403,
+ expected_code=HTTPStatus.FORBIDDEN,
expected_fields={"errcode": "M_LIMIT_EXCEEDED"},
),
param(
name="Tuple",
value=(Codes.SERVER_NOT_TRUSTED, {"additional_field": "12345"}),
- expected_code=403,
+ expected_code=HTTPStatus.FORBIDDEN,
expected_fields={
"errcode": "M_SERVER_NOT_TRUSTED",
"additional_field": "12345",
@@ -1584,7 +1635,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
channel = self.make_request("PUT", path, "{}")
# Then I am allowed
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
def test_normal_user_can_not_post_state_event(self) -> None:
# Given I am a normal member of a room
@@ -1598,7 +1649,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
channel = self.make_request("PUT", path, "{}")
# Then I am not allowed because state events require PL>=50
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
self.assertEqual(
"You don't have permission to post that to the room. "
"user_level (0) < send_level (50)",
@@ -1625,7 +1676,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
channel = self.make_request("PUT", path, "{}")
# Then I am allowed
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
@unittest.override_config(
{
@@ -1653,7 +1704,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
channel = self.make_request("PUT", path, "{}")
# Then I am not allowed
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
@unittest.override_config(
{
@@ -1681,7 +1732,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
channel = self.make_request("PUT", path, "{}")
# Then I am not allowed
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
self.assertEqual(
"You don't have permission to post that to the room. "
+ "user_level (0) < send_level (1)",
@@ -1712,7 +1763,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
# Then I am not allowed because the public_chat config does not
# affect this room, because this room is a private_chat
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
self.assertEqual(
"You don't have permission to post that to the room. "
+ "user_level (0) < send_level (50)",
@@ -1731,7 +1782,7 @@ class RoomInitialSyncTestCase(RoomBase):
def test_initial_sync(self) -> None:
channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertEqual(self.room_id, channel.json_body["room_id"])
self.assertEqual("join", channel.json_body["membership"])
@@ -1774,7 +1825,7 @@ class RoomMessageListTestCase(RoomBase):
channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertTrue("start" in channel.json_body)
self.assertEqual(token, channel.json_body["start"])
self.assertTrue("chunk" in channel.json_body)
@@ -1785,7 +1836,7 @@ class RoomMessageListTestCase(RoomBase):
channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertTrue("start" in channel.json_body)
self.assertEqual(token, channel.json_body["start"])
self.assertTrue("chunk" in channel.json_body)
@@ -1824,7 +1875,7 @@ class RoomMessageListTestCase(RoomBase):
json.dumps({"types": [EventTypes.Message]}),
),
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
@@ -1852,7 +1903,7 @@ class RoomMessageListTestCase(RoomBase):
json.dumps({"types": [EventTypes.Message]}),
),
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 1, [event["content"] for event in chunk])
@@ -1869,7 +1920,7 @@ class RoomMessageListTestCase(RoomBase):
json.dumps({"types": [EventTypes.Message]}),
),
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
@@ -1997,14 +2048,14 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
def test_restricted_no_auth(self) -> None:
channel = self.make_request("GET", self.url)
- self.assertEqual(channel.code, 401, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
def test_restricted_auth(self) -> None:
self.register_user("user", "pass")
tok = self.login("user", "pass")
channel = self.make_request("GET", self.url, access_token=tok)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
@@ -2123,7 +2174,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
content={"filter": search_filter},
access_token=self.token,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.federation_client.get_public_rooms.assert_called_once_with( # type: ignore[attr-defined]
"testserv",
@@ -2140,7 +2191,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
# The `get_public_rooms` should be called again if the first call fails
# with a 404, when using search filters.
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
- HttpResponseException(404, "Not Found", b""),
+ HttpResponseException(HTTPStatus.NOT_FOUND, "Not Found", b""),
make_awaitable({}),
)
@@ -2152,7 +2203,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
content={"filter": search_filter},
access_token=self.token,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.federation_client.get_public_rooms.assert_has_calls( # type: ignore[attr-defined]
[
@@ -2198,21 +2249,19 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
# Set a profile for the test user
self.displayname = "test user"
- data = {"displayname": self.displayname}
- request_data = json.dumps(data)
+ request_data = {"displayname": self.displayname}
channel = self.make_request(
"PUT",
"/_matrix/client/r0/profile/%s/displayname" % (self.user_id,),
request_data,
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
def test_per_room_profile_forbidden(self) -> None:
- data = {"membership": "join", "displayname": "other test user"}
- request_data = json.dumps(data)
+ request_data = {"membership": "join", "displayname": "other test user"}
channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
@@ -2220,7 +2269,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
request_data,
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -2228,7 +2277,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
res_displayname = channel.json_body["content"]["displayname"]
self.assertEqual(res_displayname, self.displayname, channel.result)
@@ -2262,7 +2311,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason},
access_token=self.second_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -2276,7 +2325,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason},
access_token=self.second_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -2290,7 +2339,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.second_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -2304,7 +2353,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -2316,7 +2365,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -2328,7 +2377,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -2347,7 +2396,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason},
access_token=self.second_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -2359,7 +2408,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
),
access_token=self.creator_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
event_content = channel.json_body
@@ -2407,7 +2456,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id, json.dumps(self.FILTER_LABELS)),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
events_before = channel.json_body["events_before"]
@@ -2437,7 +2486,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
events_before = channel.json_body["events_before"]
@@ -2472,7 +2521,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
events_before = channel.json_body["events_before"]
@@ -2552,16 +2601,14 @@ class LabelsTestCase(unittest.HomeserverTestCase):
def test_search_filter_labels(self) -> None:
"""Test that we can filter by a label on a /search request."""
- request_data = json.dumps(
- {
- "search_categories": {
- "room_events": {
- "search_term": "label",
- "filter": self.FILTER_LABELS,
- }
+ request_data = {
+ "search_categories": {
+ "room_events": {
+ "search_term": "label",
+ "filter": self.FILTER_LABELS,
}
}
- )
+ }
self._send_labelled_messages_in_room()
@@ -2589,16 +2636,14 @@ class LabelsTestCase(unittest.HomeserverTestCase):
def test_search_filter_not_labels(self) -> None:
"""Test that we can filter by the absence of a label on a /search request."""
- request_data = json.dumps(
- {
- "search_categories": {
- "room_events": {
- "search_term": "label",
- "filter": self.FILTER_NOT_LABELS,
- }
+ request_data = {
+ "search_categories": {
+ "room_events": {
+ "search_term": "label",
+ "filter": self.FILTER_NOT_LABELS,
}
}
- )
+ }
self._send_labelled_messages_in_room()
@@ -2638,16 +2683,14 @@ class LabelsTestCase(unittest.HomeserverTestCase):
"""Test that we can filter by both a label and the absence of another label on a
/search request.
"""
- request_data = json.dumps(
- {
- "search_categories": {
- "room_events": {
- "search_term": "label",
- "filter": self.FILTER_LABELS_NOT_LABELS,
- }
+ request_data = {
+ "search_categories": {
+ "room_events": {
+ "search_term": "label",
+ "filter": self.FILTER_LABELS_NOT_LABELS,
}
}
- )
+ }
self._send_labelled_messages_in_room()
@@ -2820,7 +2863,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
return channel.json_body["chunk"]
@@ -2925,7 +2968,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
events_before = channel.json_body["events_before"]
@@ -2991,7 +3034,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id),
access_token=invited_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
events_before = channel.json_body["events_before"]
@@ -3092,8 +3135,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None:
url = "/_matrix/client/r0/directory/room/" + alias
- data = {"room_id": self.room_id}
- request_data = json.dumps(data)
+ request_data = {"room_id": self.room_id}
channel = self.make_request(
"PUT", url, request_data, access_token=self.room_owner_tok
@@ -3122,8 +3164,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None:
url = "/_matrix/client/r0/directory/room/" + alias
- data = {"room_id": self.room_id}
- request_data = json.dumps(data)
+ request_data = {"room_id": self.room_id}
channel = self.make_request(
"PUT", url, request_data, access_token=self.room_owner_tok
@@ -3149,7 +3190,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
"rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
- json.dumps(content),
+ content,
access_token=self.room_owner_tok,
)
self.assertEqual(channel.code, expected_code, channel.result)
@@ -3283,7 +3324,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Mock a few functions to prevent the test from failing due to failing to talk to
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
- make_invite_mock = Mock(return_value=make_awaitable(0))
+ make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
self.hs.get_identity_handler().lookup_3pid = Mock(
return_value=make_awaitable(None),
@@ -3344,7 +3385,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Mock a few functions to prevent the test from failing due to failing to talk to
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
- make_invite_mock = Mock(return_value=make_awaitable(0))
+ make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
self.hs.get_identity_handler().lookup_3pid = Mock(
return_value=make_awaitable(None),
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index e3efd1f1b0..b085c50356 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -606,11 +606,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
self._check_unread_count(1)
# Send a read receipt to tell the server we've read the latest event.
- body = json.dumps({ReceiptTypes.READ: res["event_id"]}).encode("utf8")
channel = self.make_request(
"POST",
f"/rooms/{self.room_id}/read_markers",
- body,
+ {ReceiptTypes.READ: res["event_id"]},
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 5eb0f243f7..9a48e9286f 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -21,7 +21,6 @@ from synapse.api.constants import EventTypes, LoginType, Membership
from synapse.api.errors import SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.rest import admin
from synapse.rest.client import account, login, profile, room
@@ -113,14 +112,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Have this homeserver skip event auth checks. This is necessary due to
# event auth checks ensuring that events were signed by the sender's homeserver.
- async def _check_event_auth(
- origin: str,
- event: EventBase,
- context: EventContext,
- *args: Any,
- **kwargs: Any,
- ) -> EventContext:
- return context
+ async def _check_event_auth(origin: Any, event: Any, context: Any) -> None:
+ pass
hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment]
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 93f749744d..105d418698 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -136,7 +136,7 @@ class RestHelper:
self.site,
"POST",
path,
- json.dumps(content).encode("utf8"),
+ content,
custom_headers=custom_headers,
)
@@ -210,7 +210,7 @@ class RestHelper:
self.site,
"POST",
path,
- json.dumps(data).encode("utf8"),
+ data,
)
assert (
@@ -309,7 +309,7 @@ class RestHelper:
self.site,
"PUT",
path,
- json.dumps(data).encode("utf8"),
+ data,
)
assert (
@@ -392,7 +392,7 @@ class RestHelper:
self.site,
"PUT",
path,
- json.dumps(content or {}).encode("utf8"),
+ content or {},
custom_headers=custom_headers,
)
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 79727c430f..d18fc13c21 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -126,7 +126,9 @@ class _TestImage:
expected_scaled: The expected bytes from scaled thumbnailing, or None if
test should just check for a valid image returned.
expected_found: True if the file should exist on the server, or False if
- a 404 is expected.
+ a 404/400 is expected.
+ unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or
+ False if the thumbnailing should succeed or a normal 404 is expected.
"""
data: bytes
@@ -135,6 +137,7 @@ class _TestImage:
expected_cropped: Optional[bytes] = None
expected_scaled: Optional[bytes] = None
expected_found: bool = True
+ unable_to_thumbnail: bool = False
@parameterized_class(
@@ -192,6 +195,7 @@ class _TestImage:
b"image/gif",
b".gif",
expected_found=False,
+ unable_to_thumbnail=True,
),
),
],
@@ -366,18 +370,29 @@ class MediaRepoTests(unittest.HomeserverTestCase):
def test_thumbnail_crop(self) -> None:
"""Test that a cropped remote thumbnail is available."""
self._test_thumbnail(
- "crop", self.test_image.expected_cropped, self.test_image.expected_found
+ "crop",
+ self.test_image.expected_cropped,
+ expected_found=self.test_image.expected_found,
+ unable_to_thumbnail=self.test_image.unable_to_thumbnail,
)
def test_thumbnail_scale(self) -> None:
"""Test that a scaled remote thumbnail is available."""
self._test_thumbnail(
- "scale", self.test_image.expected_scaled, self.test_image.expected_found
+ "scale",
+ self.test_image.expected_scaled,
+ expected_found=self.test_image.expected_found,
+ unable_to_thumbnail=self.test_image.unable_to_thumbnail,
)
def test_invalid_type(self) -> None:
"""An invalid thumbnail type is never available."""
- self._test_thumbnail("invalid", None, False)
+ self._test_thumbnail(
+ "invalid",
+ None,
+ expected_found=False,
+ unable_to_thumbnail=self.test_image.unable_to_thumbnail,
+ )
@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
@@ -386,7 +401,12 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"""
Override the config to generate only scaled thumbnails, but request a cropped one.
"""
- self._test_thumbnail("crop", None, False)
+ self._test_thumbnail(
+ "crop",
+ None,
+ expected_found=False,
+ unable_to_thumbnail=self.test_image.unable_to_thumbnail,
+ )
@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
@@ -395,14 +415,22 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"""
Override the config to generate only cropped thumbnails, but request a scaled one.
"""
- self._test_thumbnail("scale", None, False)
+ self._test_thumbnail(
+ "scale",
+ None,
+ expected_found=False,
+ unable_to_thumbnail=self.test_image.unable_to_thumbnail,
+ )
def test_thumbnail_repeated_thumbnail(self) -> None:
"""Test that fetching the same thumbnail works, and deleting the on disk
thumbnail regenerates it.
"""
self._test_thumbnail(
- "scale", self.test_image.expected_scaled, self.test_image.expected_found
+ "scale",
+ self.test_image.expected_scaled,
+ expected_found=self.test_image.expected_found,
+ unable_to_thumbnail=self.test_image.unable_to_thumbnail,
)
if not self.test_image.expected_found:
@@ -459,8 +487,24 @@ class MediaRepoTests(unittest.HomeserverTestCase):
)
def _test_thumbnail(
- self, method: str, expected_body: Optional[bytes], expected_found: bool
+ self,
+ method: str,
+ expected_body: Optional[bytes],
+ expected_found: bool,
+ unable_to_thumbnail: bool = False,
) -> None:
+ """Test the given thumbnailing method works as expected.
+
+ Args:
+ method: The thumbnailing method to use (crop, scale).
+ expected_body: The expected bytes from thumbnailing, or None if
+ test should just check for a valid image.
+ expected_found: True if the file should exist on the server, or False if
+ a 404/400 is expected.
+ unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or
+ False if the thumbnailing should succeed or a normal 404 is expected.
+ """
+
params = "?width=32&height=32&method=" + method
channel = make_request(
self.reactor,
@@ -496,6 +540,16 @@ class MediaRepoTests(unittest.HomeserverTestCase):
else:
# ensure that the result is at least some valid image
Image.open(BytesIO(channel.result["body"]))
+ elif unable_to_thumbnail:
+ # A 400 with a JSON body.
+ self.assertEqual(channel.code, 400)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "errcode": "M_UNKNOWN",
+ "error": "Cannot find any thumbnails for the requested media ([b'example.com', b'12345']). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
+ },
+ )
else:
# A 404 with a JSON body.
self.assertEqual(channel.code, 404)
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 38963ce4a7..46d829b062 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -143,7 +143,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
self.event_id = res["event_id"]
# Reset the event cache so the tests start with it empty
- self.store._get_event_cache.clear()
+ self.get_success(self.store._get_event_cache.clear())
def test_simple(self):
"""Test that we cache events that we pull from the DB."""
@@ -160,7 +160,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
"""
# Reset the event cache
- self.store._get_event_cache.clear()
+ self.get_success(self.store._get_event_cache.clear())
with LoggingContext("test") as ctx:
# We keep hold of the event event though we never use it.
@@ -170,7 +170,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
# Reset the event cache
- self.store._get_event_cache.clear()
+ self.get_success(self.store._get_event_cache.clear())
with LoggingContext("test") as ctx:
self.get_success(self.store.get_event(self.event_id))
@@ -345,7 +345,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
self.event_id = res["event_id"]
# Reset the event cache so the tests start with it empty
- self.store._get_event_cache.clear()
+ self.get_success(self.store._get_event_cache.clear())
@contextmanager
def blocking_get_event_calls(
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index e8c53f16d9..ba40124c8a 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from unittest.mock import Mock
-
from twisted.test.proto_helpers import MemoryReactor
+from synapse.rest import admin
+from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.util import Clock
@@ -24,15 +24,14 @@ from tests.unittest import HomeserverTestCase
USER_ID = "@user:example.com"
-PlAIN_NOTIF = ["notify", {"set_tweak": "highlight", "value": False}]
-HIGHLIGHT = [
- "notify",
- {"set_tweak": "sound", "value": "default"},
- {"set_tweak": "highlight"},
-]
-
class EventPushActionsStoreTestCase(HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
persist_events_store = hs.get_datastores().persist_events
@@ -54,154 +53,118 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
)
def test_count_aggregation(self) -> None:
- room_id = "!foo:example.com"
- user_id = "@user1235:test"
+ # Create a user to receive notifications and send receipts.
+ user_id = self.register_user("user1235", "pass")
+ token = self.login("user1235", "pass")
+
+ # And another users to send events.
+ other_id = self.register_user("other", "pass")
+ other_token = self.login("other", "pass")
+
+ # Create a room and put both users in it.
+ room_id = self.helper.create_room_as(user_id, tok=token)
+ self.helper.join(room_id, other_id, tok=other_token)
- last_read_stream_ordering = [0]
+ last_event_id: str
- def _assert_counts(noitf_count: int, highlight_count: int) -> None:
+ def _assert_counts(
+ noitf_count: int, unread_count: int, highlight_count: int
+ ) -> None:
counts = self.get_success(
self.store.db_pool.runInteraction(
- "",
- self.store._get_unread_counts_by_pos_txn,
+ "get-unread-counts",
+ self.store._get_unread_counts_by_receipt_txn,
room_id,
user_id,
- last_read_stream_ordering[0],
)
)
self.assertEqual(
counts,
NotifCounts(
notify_count=noitf_count,
- unread_count=0, # Unread counts are tested in the sync tests.
+ unread_count=unread_count,
highlight_count=highlight_count,
),
)
- def _inject_actions(stream: int, action: list) -> None:
- event = Mock()
- event.room_id = room_id
- event.event_id = f"$test{stream}:example.com"
- event.internal_metadata.stream_ordering = stream
- event.internal_metadata.is_outlier.return_value = False
- event.depth = stream
-
- self.store._events_stream_cache.entity_has_changed(room_id, stream)
-
- self.get_success(
- self.store.db_pool.simple_insert(
- table="events",
- values={
- "stream_ordering": stream,
- "topological_ordering": stream,
- "type": "m.room.message",
- "room_id": room_id,
- "processed": True,
- "outlier": False,
- "event_id": event.event_id,
- },
- )
+ def _create_event(highlight: bool = False) -> str:
+ result = self.helper.send_event(
+ room_id,
+ type="m.room.message",
+ content={"msgtype": "m.text", "body": user_id if highlight else "msg"},
+ tok=other_token,
)
+ nonlocal last_event_id
+ last_event_id = result["event_id"]
+ return last_event_id
- self.get_success(
- self.store.add_push_actions_to_staging(
- event.event_id,
- {user_id: action},
- False,
- )
- )
- self.get_success(
- self.store.db_pool.runInteraction(
- "",
- self.persist_events_store._set_push_actions_for_event_and_users_txn,
- [(event, None)],
- [(event, None)],
- )
- )
-
- def _rotate(stream: int) -> None:
- self.get_success(
- self.store.db_pool.runInteraction(
- "rotate-receipts", self.store._handle_new_receipts_for_notifs_txn
- )
- )
-
- self.get_success(
- self.store.db_pool.runInteraction(
- "rotate-notifs", self.store._rotate_notifs_before_txn, stream
- )
- )
-
- def _mark_read(stream: int, depth: int) -> None:
- last_read_stream_ordering[0] = stream
+ def _rotate() -> None:
+ self.get_success(self.store._rotate_notifs())
+ def _mark_read(event_id: str) -> None:
self.get_success(
self.store.insert_receipt(
room_id,
"m.read",
user_id=user_id,
- event_ids=[f"$test{stream}:example.com"],
+ event_ids=[event_id],
data={},
)
)
- _assert_counts(0, 0)
- _inject_actions(1, PlAIN_NOTIF)
- _assert_counts(1, 0)
- _rotate(1)
- _assert_counts(1, 0)
+ _assert_counts(0, 0, 0)
+ _create_event()
+ _assert_counts(1, 1, 0)
+ _rotate()
+ _assert_counts(1, 1, 0)
- _inject_actions(3, PlAIN_NOTIF)
- _assert_counts(2, 0)
- _rotate(3)
- _assert_counts(2, 0)
+ event_id = _create_event()
+ _assert_counts(2, 2, 0)
+ _rotate()
+ _assert_counts(2, 2, 0)
- _inject_actions(5, PlAIN_NOTIF)
- _mark_read(3, 3)
- _assert_counts(1, 0)
+ _create_event()
+ _mark_read(event_id)
+ _assert_counts(1, 1, 0)
- _mark_read(5, 5)
- _assert_counts(0, 0)
+ _mark_read(last_event_id)
+ _assert_counts(0, 0, 0)
- _inject_actions(6, PlAIN_NOTIF)
- _rotate(6)
- _assert_counts(1, 0)
-
- self.get_success(
- self.store.db_pool.simple_delete(
- table="event_push_actions", keyvalues={"1": 1}, desc=""
- )
- )
+ _create_event()
+ _rotate()
+ _assert_counts(1, 1, 0)
- _assert_counts(1, 0)
+ # Delete old event push actions, this should not affect the (summarised) count.
+ self.get_success(self.store._remove_old_push_actions_that_have_rotated())
+ _assert_counts(1, 1, 0)
- _mark_read(6, 6)
- _assert_counts(0, 0)
+ _mark_read(last_event_id)
+ _assert_counts(0, 0, 0)
- _inject_actions(8, HIGHLIGHT)
- _assert_counts(1, 1)
- _rotate(8)
- _assert_counts(1, 1)
+ event_id = _create_event(True)
+ _assert_counts(1, 1, 1)
+ _rotate()
+ _assert_counts(1, 1, 1)
# Check that adding another notification and rotating after highlight
# works.
- _inject_actions(10, PlAIN_NOTIF)
- _rotate(10)
- _assert_counts(2, 1)
+ _create_event()
+ _rotate()
+ _assert_counts(2, 2, 1)
# Check that sending read receipts at different points results in the
# right counts.
- _mark_read(8, 8)
- _assert_counts(1, 0)
- _mark_read(10, 10)
- _assert_counts(0, 0)
-
- _inject_actions(11, HIGHLIGHT)
- _assert_counts(1, 1)
- _mark_read(11, 11)
- _assert_counts(0, 0)
- _rotate(11)
- _assert_counts(0, 0)
+ _mark_read(event_id)
+ _assert_counts(1, 1, 0)
+ _mark_read(last_event_id)
+ _assert_counts(0, 0, 0)
+
+ _create_event(True)
+ _assert_counts(1, 1, 1)
+ _mark_read(last_event_id)
+ _assert_counts(0, 0, 0)
+ _rotate()
+ _assert_counts(0, 0, 0)
def test_find_first_stream_ordering_after_ts(self) -> None:
def add_event(so: int, ts: int) -> None:
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 8dfaa0559b..9c1182ed16 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -115,6 +115,6 @@ class PurgeTests(HomeserverTestCase):
)
# The events aren't found.
- self.store._invalidate_get_event_cache(create_event.event_id)
+ self.store._invalidate_local_get_event_cache(create_event.event_id)
self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 1218786d79..240b02cb9f 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -23,7 +23,6 @@ from synapse.util import Clock
from tests import unittest
from tests.server import TestHomeServer
-from tests.test_utils import event_injection
class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
@@ -110,60 +109,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# It now knows about Charlie's server.
self.assertEqual(self.store._known_servers_count, 2)
- def test_get_joined_users_from_context(self) -> None:
- room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
- bob_event = self.get_success(
- event_injection.inject_member_event(
- self.hs, room, self.u_bob, Membership.JOIN
- )
- )
-
- # first, create a regular event
- event, context = self.get_success(
- event_injection.create_event(
- self.hs,
- room_id=room,
- sender=self.u_alice,
- prev_event_ids=[bob_event.event_id],
- type="m.test.1",
- content={},
- )
- )
-
- users = self.get_success(
- self.store.get_joined_users_from_context(event, context)
- )
- self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
-
- # Regression test for #7376: create a state event whose key matches bob's
- # user_id, but which is *not* a membership event, and persist that; then check
- # that `get_joined_users_from_context` returns the correct users for the next event.
- non_member_event = self.get_success(
- event_injection.inject_event(
- self.hs,
- room_id=room,
- sender=self.u_bob,
- prev_event_ids=[bob_event.event_id],
- type="m.test.2",
- state_key=self.u_bob,
- content={},
- )
- )
- event, context = self.get_success(
- event_injection.create_event(
- self.hs,
- room_id=room,
- sender=self.u_alice,
- prev_event_ids=[non_member_event.event_id],
- type="m.test.3",
- content={},
- )
- )
- users = self.get_success(
- self.store.get_joined_users_from_context(event, context)
- )
- self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
-
def test__null_byte_in_display_name_properly_handled(self) -> None:
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 371cd201af..e42d7b9ba0 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -19,7 +19,7 @@ from parameterized import parameterized
from synapse import event_auth
from synapse.api.constants import EventContentFields
-from synapse.api.errors import AuthError
+from synapse.api.errors import AuthError, SynapseError
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
from synapse.events import EventBase, make_event_from_dict
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
@@ -689,6 +689,45 @@ class EventAuthTestCase(unittest.TestCase):
auth_events.values(),
)
+ def test_room_v10_rejects_string_power_levels(self) -> None:
+ pl_event_content = {"users_default": "42"}
+ pl_event = make_event_from_dict(
+ {
+ "room_id": TEST_ROOM_ID,
+ **_maybe_get_event_id_dict_for_room_version(RoomVersions.V10),
+ "type": "m.room.power_levels",
+ "sender": "@test:test.com",
+ "state_key": "",
+ "content": pl_event_content,
+ "signatures": {"test.com": {"ed25519:0": "some9signature"}},
+ },
+ room_version=RoomVersions.V10,
+ )
+
+ pl_event2_content = {"events": {"m.room.name": "42", "m.room.power_levels": 42}}
+ pl_event2 = make_event_from_dict(
+ {
+ "room_id": TEST_ROOM_ID,
+ **_maybe_get_event_id_dict_for_room_version(RoomVersions.V10),
+ "type": "m.room.power_levels",
+ "sender": "@test:test.com",
+ "state_key": "",
+ "content": pl_event2_content,
+ "signatures": {"test.com": {"ed25519:0": "some9signature"}},
+ },
+ room_version=RoomVersions.V10,
+ )
+
+ with self.assertRaises(SynapseError):
+ event_auth._check_power_levels(
+ pl_event.room_version, pl_event, {("fake_type", "fake_key"): pl_event2}
+ )
+
+ with self.assertRaises(SynapseError):
+ event_auth._check_power_levels(
+ pl_event.room_version, pl_event2, {("fake_type", "fake_key"): pl_event}
+ )
+
# helpers for making events
TEST_DOMAIN = "example.com"
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 0cbef70bfa..779fad1f63 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -81,12 +81,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.handler = self.homeserver.get_federation_handler()
federation_event_handler = self.homeserver.get_federation_event_handler()
- async def _check_event_auth(
- origin,
- event,
- context,
- ):
- return context
+ async def _check_event_auth(origin, event, context):
+ pass
federation_event_handler._check_event_auth = _check_event_auth
self.client = self.homeserver.get_federation_client()
diff --git a/tests/test_server.py b/tests/test_server.py
index fc4bce899c..2fe4411401 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -231,7 +231,7 @@ class OptionsResourceTests(unittest.TestCase):
parse_listener_def({"type": "http", "port": 0}),
self.resource,
"1.0",
- max_request_body_size=1234,
+ max_request_body_size=4096,
reactor=self.reactor,
)
diff --git a/tests/test_state.py b/tests/test_state.py
index 6ca8d8f21d..bafd6d1750 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -21,7 +21,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
-from synapse.state import StateHandler, StateResolutionHandler
+from synapse.state import StateHandler, StateResolutionHandler, _make_state_cache_entry
from synapse.util import Clock
from synapse.util.macaroons import MacaroonGenerator
@@ -99,6 +99,10 @@ class _DummyStore:
state_group = self._next_group
self._next_group += 1
+ if current_state_ids is None:
+ current_state_ids = dict(self._group_to_state[prev_group])
+ current_state_ids.update(delta_ids)
+
self._group_to_state[state_group] = dict(current_state_ids)
return state_group
@@ -760,3 +764,43 @@ class StateTestCase(unittest.TestCase):
result = yield defer.ensureDeferred(self.state.compute_event_context(event))
return result
+
+ def test_make_state_cache_entry(self):
+ "Test that calculating a prev_group and delta is correct"
+
+ new_state = {
+ ("a", ""): "E",
+ ("b", ""): "E",
+ ("c", ""): "E",
+ ("d", ""): "E",
+ }
+
+ # old_state_1 has fewer differences to new_state than old_state_2, but
+ # the delta involves deleting a key, which isn't allowed in the deltas,
+ # so we should pick old_state_2 as the prev_group.
+
+ # `old_state_1` has two differences: `a` and `e`
+ old_state_1 = {
+ ("a", ""): "F",
+ ("b", ""): "E",
+ ("c", ""): "E",
+ ("d", ""): "E",
+ ("e", ""): "E",
+ }
+
+ # `old_state_2` has three differences: `a`, `c` and `d`
+ old_state_2 = {
+ ("a", ""): "F",
+ ("b", ""): "E",
+ ("c", ""): "F",
+ ("d", ""): "F",
+ }
+
+ entry = _make_state_cache_entry(new_state, {1: old_state_1, 2: old_state_2})
+
+ self.assertEqual(entry.prev_group, 2)
+
+ # There are three changes from `old_state_2` to `new_state`
+ self.assertEqual(
+ entry.delta_ids, {("a", ""): "E", ("c", ""): "E", ("d", ""): "E"}
+ )
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 37fada5c53..d3c13cf14c 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactorClock
@@ -51,7 +50,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
def test_ui_auth(self):
# Do a UI auth request
- request_data = json.dumps({"username": "kermit", "password": "monkey"})
+ request_data = {"username": "kermit", "password": "monkey"}
channel = self.make_request(b"POST", self.url, request_data)
self.assertEqual(channel.result["code"], b"401", channel.result)
@@ -82,16 +81,14 @@ class TermsTestCase(unittest.HomeserverTestCase):
self.assertDictContainsSubset(channel.json_body["params"], expected_params)
# We have to complete the dummy auth stage before completing the terms stage
- request_data = json.dumps(
- {
- "username": "kermit",
- "password": "monkey",
- "auth": {
- "session": channel.json_body["session"],
- "type": "m.login.dummy",
- },
- }
- )
+ request_data = {
+ "username": "kermit",
+ "password": "monkey",
+ "auth": {
+ "session": channel.json_body["session"],
+ "type": "m.login.dummy",
+ },
+ }
self.registration_handler.check_username = Mock(return_value=True)
@@ -102,16 +99,14 @@ class TermsTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"401", channel.result)
# Finish the UI auth for terms
- request_data = json.dumps(
- {
- "username": "kermit",
- "password": "monkey",
- "auth": {
- "session": channel.json_body["session"],
- "type": "m.login.terms",
- },
- }
- )
+ request_data = {
+ "username": "kermit",
+ "password": "monkey",
+ "auth": {
+ "session": channel.json_body["session"],
+ "type": "m.login.terms",
+ },
+ }
channel = self.make_request(b"POST", self.url, request_data)
# We're interested in getting a response that looks like a successful
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index f338af6c36..c385b2f8d4 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -272,7 +272,7 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
"state_key": "@user:test",
"content": {"membership": "invite"},
}
- self.add_hashes_and_signatures(invite_pdu)
+ self.add_hashes_and_signatures_from_other_server(invite_pdu)
invite_event_id = make_event_from_dict(invite_pdu, RoomVersions.V9).event_id
self.get_success(
diff --git a/tests/unittest.py b/tests/unittest.py
index c645dd3563..66ce92f4a6 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -16,7 +16,6 @@
import gc
import hashlib
import hmac
-import json
import logging
import secrets
import time
@@ -285,7 +284,7 @@ class HomeserverTestCase(TestCase):
config=self.hs.config.server.listeners[0],
resource=self.resource,
server_version_string="1",
- max_request_body_size=1234,
+ max_request_body_size=4096,
reactor=self.reactor,
)
@@ -619,20 +618,16 @@ class HomeserverTestCase(TestCase):
want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
want_mac_digest = want_mac.hexdigest()
- body = json.dumps(
- {
- "nonce": nonce,
- "username": username,
- "displayname": displayname,
- "password": password,
- "admin": admin,
- "mac": want_mac_digest,
- "inhibit_login": True,
- }
- )
- channel = self.make_request(
- "POST", "/_synapse/admin/v1/register", body.encode("utf8")
- )
+ body = {
+ "nonce": nonce,
+ "username": username,
+ "displayname": displayname,
+ "password": password,
+ "admin": admin,
+ "mac": want_mac_digest,
+ "inhibit_login": True,
+ }
+ channel = self.make_request("POST", "/_synapse/admin/v1/register", body)
self.assertEqual(channel.code, 200, channel.json_body)
user_id = channel.json_body["user_id"]
@@ -676,9 +671,7 @@ class HomeserverTestCase(TestCase):
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
) -> str:
"""
- Log in a user, and get an access token. Requires the Login API be
- registered.
-
+ Log in a user, and get an access token. Requires the Login API be registered.
"""
body = {"type": "m.login.password", "user": username, "password": password}
if device_id:
@@ -687,7 +680,7 @@ class HomeserverTestCase(TestCase):
channel = self.make_request(
"POST",
"/_matrix/client/r0/login",
- json.dumps(body).encode("utf8"),
+ body,
custom_headers=custom_headers,
)
self.assertEqual(channel.code, 200, channel.result)
@@ -780,7 +773,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
verify_key_id,
FetchKeyResult(
verify_key=verify_key,
- valid_until_ts=clock.time_msec() + 1000,
+ valid_until_ts=clock.time_msec() + 10000,
),
)
],
@@ -838,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
client_ip=client_ip,
)
- def add_hashes_and_signatures(
+ def add_hashes_and_signatures_from_other_server(
self,
event_dict: JsonDict,
room_version: RoomVersion = KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
diff --git a/tests/utils.py b/tests/utils.py
index 424cc4c2a0..d2c6d1e852 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -167,6 +167,7 @@ def default_config(
"local": {"per_second": 10000, "burst_count": 10000},
"remote": {"per_second": 10000, "burst_count": 10000},
},
+ "rc_joins_per_room": {"per_second": 10000, "burst_count": 10000},
"rc_invites": {
"per_room": {"per_second": 10000, "burst_count": 10000},
"per_user": {"per_second": 10000, "burst_count": 10000},
|