diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index bc75ddd3e9..e0f363555b 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -19,6 +19,7 @@ import pymacaroons
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.auth import Auth
+from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import UserTypes
from synapse.api.errors import (
AuthError,
@@ -49,7 +50,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
# AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
- self.auth_blocking = self.auth._auth_blocking
+ self.auth_blocking = AuthBlocking(hs)
self.test_user = "@foo:bar"
self.test_token = b"_test_token_"
@@ -283,10 +284,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
TokenLookupResult(
user_id="@baldrick:matrix.org",
device_id="device",
+ token_id=5,
token_owner="@admin:matrix.org",
+ token_used=True,
)
)
self.store.insert_client_ip = simple_async_mock(None)
+ self.store.mark_access_token_as_used = simple_async_mock(None)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
@@ -300,10 +304,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
TokenLookupResult(
user_id="@baldrick:matrix.org",
device_id="device",
+ token_id=5,
token_owner="@admin:matrix.org",
+ token_used=True,
)
)
self.store.insert_client_ip = simple_async_mock(None)
+ self.store.mark_access_token_as_used = simple_async_mock(None)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
@@ -312,9 +319,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(self.store.insert_client_ip.call_count, 2)
def test_get_user_from_macaroon(self):
- self.store.get_user_by_access_token = simple_async_mock(
- TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
- )
+ self.store.get_user_by_access_token = simple_async_mock(None)
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
@@ -322,17 +327,14 @@ class AuthTestCase(unittest.HomeserverTestCase):
identifier="key",
key=self.hs.config.key.macaroon_secret_key,
)
+ # "Legacy" macaroons should not work for regular users not in the database
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
- user_info = self.get_success(
- self.auth.get_user_by_access_token(macaroon.serialize())
+ serialized = macaroon.serialize()
+ self.get_failure(
+ self.auth.get_user_by_access_token(serialized), InvalidClientTokenError
)
- self.assertEqual(user_id, user_info.user_id)
-
- # TODO: device_id should come from the macaroon, but currently comes
- # from the db.
- self.assertEqual(user_info.device_id, "device")
def test_get_guest_user_from_macaroon(self):
self.store.get_user_by_id = simple_async_mock({"is_guest": True})
@@ -351,7 +353,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
serialized = macaroon.serialize()
user_info = self.get_success(self.auth.get_user_by_access_token(serialized))
- self.assertEqual(user_id, user_info.user_id)
+ self.assertEqual(user_id, user_info.user.to_string())
self.assertTrue(user_info.is_guest)
self.store.get_user_by_id.assert_called_with(user_id)
@@ -362,20 +364,22 @@ class AuthTestCase(unittest.HomeserverTestCase):
small_number_of_users = 1
# Ensure no error thrown
- self.get_success(self.auth.check_auth_blocking())
+ self.get_success(self.auth_blocking.check_auth_blocking())
self.auth_blocking._limit_usage_by_mau = True
self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
- e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+ e = self.get_failure(
+ self.auth_blocking.check_auth_blocking(), ResourceLimitError
+ )
self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEqual(e.value.code, 403)
# Ensure does not throw an error
self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
- self.get_success(self.auth.check_auth_blocking())
+ self.get_success(self.auth_blocking.check_auth_blocking())
def test_blocking_mau__depending_on_user_type(self):
self.auth_blocking._max_mau_value = 50
@@ -383,15 +387,18 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_monthly_active_count = simple_async_mock(100)
# Support users allowed
- self.get_success(self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT))
+ self.get_success(
+ self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT)
+ )
self.store.get_monthly_active_count = simple_async_mock(100)
# Bots not allowed
self.get_failure(
- self.auth.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError
+ self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT),
+ ResourceLimitError,
)
self.store.get_monthly_active_count = simple_async_mock(100)
# Real users not allowed
- self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+ self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self):
self.auth_blocking._max_mau_value = 50
@@ -419,7 +426,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
app_service=appservice,
authenticated_entity="@appservice:server",
)
- self.get_success(self.auth.check_auth_blocking(requester=requester))
+ self.get_success(self.auth_blocking.check_auth_blocking(requester=requester))
def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self):
self.auth_blocking._max_mau_value = 50
@@ -448,7 +455,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
authenticated_entity="@appservice:server",
)
self.get_failure(
- self.auth.check_auth_blocking(requester=requester), ResourceLimitError
+ self.auth_blocking.check_auth_blocking(requester=requester),
+ ResourceLimitError,
)
def test_reserved_threepid(self):
@@ -459,18 +467,21 @@ class AuthTestCase(unittest.HomeserverTestCase):
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
self.auth_blocking._mau_limits_reserved_threepids = [threepid]
- self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+ self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
self.get_failure(
- self.auth.check_auth_blocking(threepid=unknown_threepid), ResourceLimitError
+ self.auth_blocking.check_auth_blocking(threepid=unknown_threepid),
+ ResourceLimitError,
)
- self.get_success(self.auth.check_auth_blocking(threepid=threepid))
+ self.get_success(self.auth_blocking.check_auth_blocking(threepid=threepid))
def test_hs_disabled(self):
self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
- e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+ e = self.get_failure(
+ self.auth_blocking.check_auth_blocking(), ResourceLimitError
+ )
self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEqual(e.value.code, 403)
@@ -485,7 +496,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
- e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+ e = self.get_failure(
+ self.auth_blocking.check_auth_blocking(), ResourceLimitError
+ )
self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEqual(e.value.code, 403)
@@ -495,4 +508,4 @@ class AuthTestCase(unittest.HomeserverTestCase):
user = "@user:server"
self.auth_blocking._server_notices_mxid = user
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
- self.get_success(self.auth.check_auth_blocking(user))
+ self.get_success(self.auth_blocking.check_auth_blocking(user))
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index f661a9ff8e..c86f783c5b 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -246,7 +246,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
self.assertTrue(allowed)
self.assertEqual(10.0, time_allowed)
- # Test that, after doing these 3 actions, we can't do any more action without
+ # Test that, after doing these 3 actions, we can't do any more actions without
# waiting.
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(None, key="test_id", n_actions=1, _time_now_s=0)
@@ -254,7 +254,8 @@ class TestRatelimiter(unittest.HomeserverTestCase):
self.assertFalse(allowed)
self.assertEqual(10.0, time_allowed)
- # Test that after waiting we can do only 1 action.
+ # Test that after waiting we would be able to do only 1 action.
+ # Note that we don't actually do it (update=False) here.
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(
None,
@@ -265,23 +266,125 @@ class TestRatelimiter(unittest.HomeserverTestCase):
)
)
self.assertTrue(allowed)
- # The time allowed is the current time because we could still repeat the action
- # once.
- self.assertEqual(10.0, time_allowed)
+ # We would be able to do the 5th action at t=20.
+ self.assertEqual(20.0, time_allowed)
+ # Attempt (but fail) to perform TWO actions at t=10.
+ # Those would be the 4th and 5th actions.
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(None, key="test_id", n_actions=2, _time_now_s=10)
)
self.assertFalse(allowed)
- # The time allowed doesn't change despite allowed being False because, while we
- # don't allow 2 actions, we could still do 1.
+ # The returned time allowed for the next action is now even though we weren't
+ # allowed to perform the action because whilst we don't allow 2 actions,
+ # we could still do 1.
self.assertEqual(10.0, time_allowed)
- # Test that after waiting a bit more we can do 2 actions.
+ # Test that after waiting until t=20, we can do perform 2 actions.
+ # These are the 4th and 5th actions.
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(None, key="test_id", n_actions=2, _time_now_s=20)
)
self.assertTrue(allowed)
- # The time allowed is the current time because we could still repeat the action
- # once.
- self.assertEqual(20.0, time_allowed)
+ # We would be able to do the 6th action at t=30.
+ self.assertEqual(30.0, time_allowed)
+
+ def test_rate_limit_burst_only_given_once(self) -> None:
+ """
+ Regression test against a bug that meant that you could build up
+ extra tokens by timing requests.
+ """
+ limiter = Ratelimiter(
+ store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+ )
+
+ def consume_at(time: float) -> bool:
+ success, _ = self.get_success_or_raise(
+ limiter.can_do_action(requester=None, key="a", _time_now_s=time)
+ )
+ return success
+
+ # Use all our 3 burst tokens
+ self.assertTrue(consume_at(0.0))
+ self.assertTrue(consume_at(0.1))
+ self.assertTrue(consume_at(0.2))
+
+ # Wait to recover 1 token (10 seconds at 0.1 Hz).
+ self.assertTrue(consume_at(10.1))
+
+ # Check that we get rate limited after using that token.
+ self.assertFalse(consume_at(11.1))
+
+ def test_record_action_which_doesnt_fill_bucket(self) -> None:
+ limiter = Ratelimiter(
+ store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+ )
+
+ # Observe two actions, leaving room in the bucket for one more.
+ limiter.record_action(requester=None, key="a", n_actions=2, _time_now_s=0.0)
+
+ # We should be able to take a new action now.
+ success, _ = self.get_success_or_raise(
+ limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
+ )
+ self.assertTrue(success)
+
+ # ... but not two.
+ success, _ = self.get_success_or_raise(
+ limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
+ )
+ self.assertFalse(success)
+
+ def test_record_action_which_fills_bucket(self) -> None:
+ limiter = Ratelimiter(
+ store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+ )
+
+ # Observe three actions, filling up the bucket.
+ limiter.record_action(requester=None, key="a", n_actions=3, _time_now_s=0.0)
+
+ # We should be unable to take a new action now.
+ success, _ = self.get_success_or_raise(
+ limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
+ )
+ self.assertFalse(success)
+
+ # If we wait 10 seconds to leak a token, we should be able to take one action...
+ success, _ = self.get_success_or_raise(
+ limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
+ )
+ self.assertTrue(success)
+
+ # ... but not two.
+ success, _ = self.get_success_or_raise(
+ limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
+ )
+ self.assertFalse(success)
+
+ def test_record_action_which_overfills_bucket(self) -> None:
+ limiter = Ratelimiter(
+ store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+ )
+
+ # Observe four actions, exceeding the bucket.
+ limiter.record_action(requester=None, key="a", n_actions=4, _time_now_s=0.0)
+
+ # We should be prevented from taking a new action now.
+ success, _ = self.get_success_or_raise(
+ limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
+ )
+ self.assertFalse(success)
+
+ # If we wait 10 seconds to leak a token, we should be unable to take an action
+ # because the bucket is still full.
+ success, _ = self.get_success_or_raise(
+ limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
+ )
+ self.assertFalse(success)
+
+ # But after another 10 seconds we leak a second token, giving us room for
+ # action.
+ success, _ = self.get_success_or_raise(
+ limiter.can_do_action(requester=None, key="a", _time_now_s=20.0)
+ )
+ self.assertTrue(success)
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index ffc3012a86..685a9a6d52 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -141,10 +141,6 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
hs = self.setup_test_homeserver(
federation_transport_client=fed_transport_client,
)
- # Load the modules into the homeserver
- module_api = hs.get_module_api()
- for module, config in hs.config.modules.loaded_modules:
- module(config=config, api=module_api)
load_legacy_presence_router(hs)
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
index 268a48d7ba..50e376f695 100644
--- a/tests/federation/test_federation_client.py
+++ b/tests/federation/test_federation_client.py
@@ -22,6 +22,7 @@ from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
@@ -38,42 +39,46 @@ class FederationClientTest(FederatingHomeserverTestCase):
self._mock_agent = mock.create_autospec(twisted.web.client.Agent, spec_set=True)
homeserver.get_federation_http_client().agent = self._mock_agent
- def test_get_room_state(self):
- creator = f"@creator:{self.OTHER_SERVER_NAME}"
- test_room_id = "!room_id"
+ # Move clock up to somewhat realistic time so the PDU destination retry
+ # works (`now` needs to be larger than `0 + PDU_RETRY_TIME_MS`).
+ self.reactor.advance(1000000000)
+
+ self.creator = f"@creator:{self.OTHER_SERVER_NAME}"
+ self.test_room_id = "!room_id"
+ def test_get_room_state(self):
# mock up some events to use in the response.
# In real life, these would have things in `prev_events` and `auth_events`, but that's
# a bit annoying to mock up, and the code under test doesn't care, so we don't bother.
- create_event_dict = self.add_hashes_and_signatures(
+ create_event_dict = self.add_hashes_and_signatures_from_other_server(
{
- "room_id": test_room_id,
+ "room_id": self.test_room_id,
"type": "m.room.create",
"state_key": "",
- "sender": creator,
- "content": {"creator": creator},
+ "sender": self.creator,
+ "content": {"creator": self.creator},
"prev_events": [],
"auth_events": [],
"origin_server_ts": 500,
}
)
- member_event_dict = self.add_hashes_and_signatures(
+ member_event_dict = self.add_hashes_and_signatures_from_other_server(
{
- "room_id": test_room_id,
+ "room_id": self.test_room_id,
"type": "m.room.member",
- "sender": creator,
- "state_key": creator,
+ "sender": self.creator,
+ "state_key": self.creator,
"content": {"membership": "join"},
"prev_events": [],
"auth_events": [],
"origin_server_ts": 600,
}
)
- pl_event_dict = self.add_hashes_and_signatures(
+ pl_event_dict = self.add_hashes_and_signatures_from_other_server(
{
- "room_id": test_room_id,
+ "room_id": self.test_room_id,
"type": "m.room.power_levels",
- "sender": creator,
+ "sender": self.creator,
"state_key": "",
"content": {},
"prev_events": [],
@@ -102,8 +107,8 @@ class FederationClientTest(FederatingHomeserverTestCase):
# now fire off the request
state_resp, auth_resp = self.get_success(
self.hs.get_federation_client().get_room_state(
- "yet_another_server",
- test_room_id,
+ "yet.another.server",
+ self.test_room_id,
"event_id",
RoomVersions.V9,
)
@@ -112,7 +117,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
# check the right call got made to the agent
self._mock_agent.request.assert_called_once_with(
b"GET",
- b"matrix://yet_another_server/_matrix/federation/v1/state/%21room_id?event_id=event_id",
+ b"matrix://yet.another.server/_matrix/federation/v1/state/%21room_id?event_id=event_id",
headers=mock.ANY,
bodyProducer=None,
)
@@ -130,6 +135,102 @@ class FederationClientTest(FederatingHomeserverTestCase):
["m.room.create", "m.room.member", "m.room.power_levels"],
)
+ def test_get_pdu_returns_nothing_when_event_does_not_exist(self):
+ """No event should be returned when the event does not exist"""
+ remote_pdu = self.get_success(
+ self.hs.get_federation_client().get_pdu(
+ ["yet.another.server"],
+ "event_should_not_exist",
+ RoomVersions.V9,
+ )
+ )
+ self.assertEqual(remote_pdu, None)
+
+ def test_get_pdu(self):
+ """Test to make sure an event is returned by `get_pdu()`"""
+ self._get_pdu_once()
+
+ def test_get_pdu_event_from_cache_is_pristine(self):
+ """Test that modifications made to events returned by `get_pdu()`
+ do not propagate back to to the internal cache (events returned should
+ be a copy).
+ """
+
+ # Get the PDU in the cache
+ remote_pdu = self._get_pdu_once()
+
+ # Modify the the event reference.
+ # This change should not make it back to the `_get_pdu_cache`.
+ remote_pdu.internal_metadata.outlier = True
+
+ # Get the event again. This time it should read it from cache.
+ remote_pdu2 = self.get_success(
+ self.hs.get_federation_client().get_pdu(
+ ["yet.another.server"],
+ remote_pdu.event_id,
+ RoomVersions.V9,
+ )
+ )
+
+ # Sanity check that we are working against the same event
+ self.assertEqual(remote_pdu.event_id, remote_pdu2.event_id)
+
+ # Make sure the event does not include modification from earlier
+ self.assertIsNotNone(remote_pdu2)
+ self.assertEqual(remote_pdu2.internal_metadata.outlier, False)
+
+ def _get_pdu_once(self) -> EventBase:
+ """Retrieve an event via `get_pdu()` and assert that an event was returned.
+ Also used to prime the cache for subsequent test logic.
+ """
+ message_event_dict = self.add_hashes_and_signatures_from_other_server(
+ {
+ "room_id": self.test_room_id,
+ "type": "m.room.message",
+ "sender": self.creator,
+ "state_key": "",
+ "content": {},
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": 700,
+ "depth": 10,
+ }
+ )
+
+ # mock up the response, and have the agent return it
+ self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
+ _mock_response(
+ {
+ "origin": "yet.another.server",
+ "origin_server_ts": 900,
+ "pdus": [
+ message_event_dict,
+ ],
+ }
+ )
+ )
+
+ remote_pdu = self.get_success(
+ self.hs.get_federation_client().get_pdu(
+ ["yet.another.server"],
+ "event_id",
+ RoomVersions.V9,
+ )
+ )
+
+ # check the right call got made to the agent
+ self._mock_agent.request.assert_called_once_with(
+ b"GET",
+ b"matrix://yet.another.server/_matrix/federation/v1/event/event_id",
+ headers=mock.ANY,
+ bodyProducer=None,
+ )
+
+ self.assertIsNotNone(remote_pdu)
+ self.assertEqual(remote_pdu.internal_metadata.outlier, False)
+
+ return remote_pdu
+
def _mock_response(resp: JsonDict):
body = json.dumps(resp).encode("utf-8")
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 01a1db6115..a5aa500ef8 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -173,17 +173,24 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
return c
def prepare(self, reactor, clock, hs):
- # stub out `get_rooms_for_user` and `get_users_in_room` so that the
+ test_room_id = "!room:host1"
+
+ # stub out `get_rooms_for_user` and `get_current_hosts_in_room` so that the
# server thinks the user shares a room with `@user2:host2`
def get_rooms_for_user(user_id):
- return defer.succeed({"!room:host1"})
+ return defer.succeed({test_room_id})
hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user
- def get_users_in_room(room_id):
- return defer.succeed({"@user2:host2"})
+ async def get_current_hosts_in_room(room_id):
+ if room_id == test_room_id:
+ return ["host2"]
+
+ # TODO: We should fail the test when we encounter an unxpected room ID.
+ # We can't just use `self.fail(...)` here because the app code is greedy
+ # with `Exception` and will catch it before the test can see it.
- hs.get_datastores().main.get_users_in_room = get_users_in_room
+ hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room
# whenever send_transaction is called, record the edu data
self.edus = []
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 413b3c9426..3a6ef221ae 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from http import HTTPStatus
from parameterized import parameterized
@@ -20,7 +21,6 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.server import DEFAULT_ROOM_VERSION
-from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events import make_event_from_dict
from synapse.federation.federation_server import server_matches_acl_event
from synapse.rest import admin
@@ -59,7 +59,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
"/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
query_content,
)
- self.assertEqual(400, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON")
@@ -120,7 +120,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/v1/state/%s?event_id=xyz" % (room_1,)
)
- self.assertEqual(403, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -148,13 +148,13 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
tok2 = self.login("fozzie", "bear")
self.helper.join(self._room_id, second_member_user_id, tok=tok2)
- def _make_join(self, user_id) -> JsonDict:
+ def _make_join(self, user_id: str) -> JsonDict:
channel = self.make_signed_federation_request(
"GET",
f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}"
f"?ver={DEFAULT_ROOM_VERSION}",
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
return channel.json_body
def test_send_join(self):
@@ -163,18 +163,16 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
join_result = self._make_join(joining_user)
join_event_dict = join_result["event"]
- add_hashes_and_signatures(
- KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
+ self.add_hashes_and_signatures_from_other_server(
join_event_dict,
- signature_name=self.OTHER_SERVER_NAME,
- signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
+ KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
)
channel = self.make_signed_federation_request(
"PUT",
f"/_matrix/federation/v2/send_join/{self._room_id}/x",
content=join_event_dict,
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# we should get complete room state back
returned_state = [
@@ -220,18 +218,16 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
join_result = self._make_join(joining_user)
join_event_dict = join_result["event"]
- add_hashes_and_signatures(
- KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
+ self.add_hashes_and_signatures_from_other_server(
join_event_dict,
- signature_name=self.OTHER_SERVER_NAME,
- signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
+ KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
)
channel = self.make_signed_federation_request(
"PUT",
f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true",
content=join_event_dict,
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# expect a reduced room state
returned_state = [
@@ -264,6 +260,67 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
)
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
+ @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 3}})
+ def test_make_join_respects_room_join_rate_limit(self) -> None:
+ # In the test setup, two users join the room. Since the rate limiter burst
+ # count is 3, a new make_join request to the room should be accepted.
+
+ joining_user = "@ronniecorbett:" + self.OTHER_SERVER_NAME
+ self._make_join(joining_user)
+
+ # Now have a new local user join the room. This saturates the rate limiter
+ # bucket, so the next make_join should be denied.
+ new_local_user = self.register_user("animal", "animal")
+ token = self.login("animal", "animal")
+ self.helper.join(self._room_id, new_local_user, tok=token)
+
+ joining_user = "@ronniebarker:" + self.OTHER_SERVER_NAME
+ channel = self.make_signed_federation_request(
+ "GET",
+ f"/_matrix/federation/v1/make_join/{self._room_id}/{joining_user}"
+ f"?ver={DEFAULT_ROOM_VERSION}",
+ )
+ self.assertEqual(channel.code, HTTPStatus.TOO_MANY_REQUESTS, channel.json_body)
+
+ @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 3}})
+ def test_send_join_contributes_to_room_join_rate_limit_and_is_limited(self) -> None:
+ # Make two make_join requests up front. (These are rate limited, but do not
+ # contribute to the rate limit.)
+ join_event_dicts = []
+ for i in range(2):
+ joining_user = f"@misspiggy{i}:{self.OTHER_SERVER_NAME}"
+ join_result = self._make_join(joining_user)
+ join_event_dict = join_result["event"]
+ self.add_hashes_and_signatures_from_other_server(
+ join_event_dict,
+ KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
+ )
+ join_event_dicts.append(join_event_dict)
+
+ # In the test setup, two users join the room. Since the rate limiter burst
+ # count is 3, the first send_join should be accepted...
+ channel = self.make_signed_federation_request(
+ "PUT",
+ f"/_matrix/federation/v2/send_join/{self._room_id}/join0",
+ content=join_event_dicts[0],
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # ... but the second should be denied.
+ channel = self.make_signed_federation_request(
+ "PUT",
+ f"/_matrix/federation/v2/send_join/{self._room_id}/join1",
+ content=join_event_dicts[1],
+ )
+ self.assertEqual(channel.code, HTTPStatus.TOO_MANY_REQUESTS, channel.json_body)
+
+ # NB: we could write a test which checks that the send_join event is seen
+ # by other workers over replication, and that they update their rate limit
+ # buckets accordingly. I'm going to assume that the join event gets sent over
+ # replication, at which point the tests.handlers.room_member test
+ # test_local_users_joining_on_another_worker_contribute_to_rate_limit
+ # is probably sufficient to reassure that the bucket is updated.
+
def _create_acl_event(content):
return make_event_from_dict(
diff --git a/tests/federation/transport/server/test__base.py b/tests/federation/transport/server/test__base.py
index e63885c1c9..e88e5d8bb3 100644
--- a/tests/federation/transport/server/test__base.py
+++ b/tests/federation/transport/server/test__base.py
@@ -18,13 +18,14 @@ from typing import Dict, List, Tuple
from synapse.api.errors import Codes
from synapse.federation.transport.server import BaseFederationServlet
from synapse.federation.transport.server._base import Authenticator, _parse_auth_header
-from synapse.http.server import JsonResource, cancellable
+from synapse.http.server import JsonResource
from synapse.server import HomeServer
from synapse.types import JsonDict
+from synapse.util.cancellation import cancellable
from synapse.util.ratelimitutils import FederationRateLimiter
from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
class CancellableFederationServlet(BaseFederationServlet):
@@ -54,9 +55,7 @@ class CancellableFederationServlet(BaseFederationServlet):
return HTTPStatus.OK, {"result": True}
-class BaseFederationServletCancellationTests(
- unittest.FederatingHomeserverTestCase, EndpointCancellationTestHelperMixin
-):
+class BaseFederationServletCancellationTests(unittest.FederatingHomeserverTestCase):
"""Tests for `BaseFederationServlet` cancellation."""
skip = "`BaseFederationServlet` does not support cancellation yet."
@@ -86,7 +85,7 @@ class BaseFederationServletCancellationTests(
# request won't be processed.
self.pump()
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@@ -106,7 +105,7 @@ class BaseFederationServletCancellationTests(
# request won't be processed.
self.pump()
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index d96d5aa138..b17af2725b 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -50,7 +50,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_scheduler = Mock()
hs = Mock()
hs.get_datastores.return_value = Mock(main=self.mock_store)
- self.mock_store.get_received_ts.return_value = make_awaitable(0)
+ self.mock_store.get_appservice_last_pos.return_value = make_awaitable(None)
self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable(
None
@@ -76,9 +76,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
)
- self.mock_store.get_new_events_for_appservice.side_effect = [
- make_awaitable((0, [])),
- make_awaitable((1, [event])),
+ self.mock_store.get_all_new_events_stream.side_effect = [
+ make_awaitable((0, [], {})),
+ make_awaitable((1, [event], {event.event_id: 0})),
]
self.handler.notify_interested_services(RoomStreamToken(None, 1))
@@ -95,8 +95,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = make_awaitable(True)
- self.mock_store.get_new_events_for_appservice.side_effect = [
- make_awaitable((0, [event])),
+ self.mock_store.get_all_new_events_stream.side_effect = [
+ make_awaitable((0, [event], {event.event_id: 0})),
]
self.handler.notify_interested_services(RoomStreamToken(None, 0))
@@ -112,8 +112,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = make_awaitable(True)
- self.mock_store.get_new_events_for_appservice.side_effect = [
- make_awaitable((0, [event])),
+ self.mock_store.get_all_new_events_stream.side_effect = [
+ make_awaitable((0, [event], {event.event_id: 0})),
]
self.handler.notify_interested_services(RoomStreamToken(None, 0))
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 67a7829769..7106799d44 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -38,7 +38,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
# MAU tests
# AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
- self.auth_blocking = hs.get_auth()._auth_blocking
+ self.auth_blocking = hs.get_auth_blocking()
self.auth_blocking._max_mau_value = 50
self.small_number_of_users = 1
diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py
index 7586e472b5..7b9b711521 100644
--- a/tests/handlers/test_deactivate_account.py
+++ b/tests/handlers/test_deactivate_account.py
@@ -11,12 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
-from typing import Any, Dict
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import AccountDataTypes
+from synapse.push.baserules import PushRule
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest import admin
from synapse.rest.client import account, login
@@ -58,7 +57,7 @@ class DeactivateAccountTestCase(HomeserverTestCase):
access_token=self.token,
)
- self.assertEqual(req.code, HTTPStatus.OK, req)
+ self.assertEqual(req.code, 200, req)
def test_global_account_data_deleted_upon_deactivation(self) -> None:
"""
@@ -131,12 +130,12 @@ class DeactivateAccountTestCase(HomeserverTestCase):
),
)
- def _is_custom_rule(self, push_rule: Dict[str, Any]) -> bool:
+ def _is_custom_rule(self, push_rule: PushRule) -> bool:
"""
Default rules start with a dot: such as .m.rule and .im.vector.
This function returns true iff a rule is custom (not default).
"""
- return "/." not in push_rule["rule_id"]
+ return "/." not in push_rule.rule_id
def test_push_rules_deleted_upon_account_deactivation(self) -> None:
"""
@@ -158,22 +157,21 @@ class DeactivateAccountTestCase(HomeserverTestCase):
)
# Test the rule exists
- push_rules = self.get_success(self._store.get_push_rules_for_user(self.user))
+ filtered_push_rules = self.get_success(
+ self._store.get_push_rules_for_user(self.user)
+ )
# Filter out default rules; we don't care
- push_rules = list(filter(self._is_custom_rule, push_rules))
+ push_rules = [r for r, _ in filtered_push_rules if self._is_custom_rule(r)]
# Check our rule made it
self.assertEqual(
push_rules,
[
- {
- "user_name": "@user:test",
- "rule_id": "personal.override.rule1",
- "priority_class": 5,
- "priority": 0,
- "conditions": [],
- "actions": [],
- "default": False,
- }
+ PushRule(
+ rule_id="personal.override.rule1",
+ priority_class=5,
+ conditions=[],
+ actions=[],
+ )
],
push_rules,
)
@@ -181,9 +179,11 @@ class DeactivateAccountTestCase(HomeserverTestCase):
# Request the deactivation of our account
self._deactivate_my_account()
- push_rules = self.get_success(self._store.get_push_rules_for_user(self.user))
+ filtered_push_rules = self.get_success(
+ self._store.get_push_rules_for_user(self.user)
+ )
# Filter out default rules; we don't care
- push_rules = list(filter(self._is_custom_rule, push_rules))
+ push_rules = [r for r, _ in filtered_push_rules if self._is_custom_rule(r)]
# Check our rule no longer exists
self.assertEqual(push_rules, [], push_rules)
@@ -322,3 +322,18 @@ class DeactivateAccountTestCase(HomeserverTestCase):
)
),
)
+
+ def test_deactivate_account_needs_auth(self) -> None:
+ """
+ Tests that making a request to /deactivate with an empty body
+ succeeds in starting the user-interactive auth flow.
+ """
+ req = self.make_request(
+ "POST",
+ "account/deactivate",
+ {},
+ access_token=self.token,
+ )
+
+ self.assertEqual(req.code, 401, req)
+ self.assertEqual(req.json_body["flows"], [{"stages": ["m.login.password"]}])
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 53d49ca896..3b72c4c9d0 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -481,17 +481,13 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
return config
- def prepare(
- self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
- ) -> HomeServer:
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
self.allowed_access_token = self.login(self.allowed_localpart, "pass")
self.denied_user_id = self.register_user("denied", "pass")
self.denied_access_token = self.login("denied", "pass")
- return hs
-
def test_denied_without_publication_permission(self) -> None:
"""
Try to create a room, register an alias for it, and publish it,
@@ -575,9 +571,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
servlets = [directory.register_servlets, room.register_servlets]
- def prepare(
- self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
- ) -> HomeServer:
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request(
@@ -588,8 +582,6 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
self.room_list_handler = hs.get_room_list_handler()
self.directory_handler = hs.get_directory_handler()
- return hs
-
def test_disabling_room_list(self) -> None:
self.room_list_handler.enable_room_list_search = True
self.directory_handler.enable_room_list_search = True
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index e0eda545b9..745750b1d7 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List, cast
+from typing import cast
from unittest import TestCase
+from unittest.mock import Mock, patch
from twisted.test.proto_helpers import MemoryReactor
@@ -22,6 +23,7 @@ from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseErro
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, make_event_from_dict
from synapse.federation.federation_base import event_from_pdu_json
+from synapse.federation.federation_client import SendJoinResult
from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client import login, room
@@ -30,7 +32,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
-from tests.test_utils import event_injection
+from tests.test_utils import event_injection, make_awaitable
logger = logging.getLogger(__name__)
@@ -50,8 +52,6 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
hs = self.setup_test_homeserver(federation_http_client=None)
self.handler = hs.get_federation_handler()
self.store = hs.get_datastores().main
- self.state_storage_controller = hs.get_storage_controllers().state
- self._event_auth_handler = hs.get_event_auth_handler()
return hs
def test_exchange_revoked_invite(self) -> None:
@@ -119,7 +119,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
# check the state group
- sg = self.successResultOf(
+ sg = self.get_success(
self.store._get_state_group_for_event(join_event.event_id)
)
@@ -149,7 +149,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
self.assertIsNotNone(e.rejected_reason)
# ... and the state group should be the same as before
- sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id))
+ sg2 = self.get_success(self.store._get_state_group_for_event(ev.event_id))
self.assertEqual(sg, sg2)
@@ -172,7 +172,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
# check the state group
- sg = self.successResultOf(
+ sg = self.get_success(
self.store._get_state_group_for_event(join_event.event_id)
)
@@ -203,7 +203,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
self.assertIsNotNone(e.rejected_reason)
# ... and the state group should be the same as before
- sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id))
+ sg2 = self.get_success(self.store._get_state_group_for_event(ev.event_id))
self.assertEqual(sg, sg2)
@@ -225,9 +225,10 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# we need a user on the remote server to be a member, so that we can send
# extremity-causing events.
+ remote_server_user_id = f"@user:{self.OTHER_SERVER_NAME}"
self.get_success(
event_injection.inject_member_event(
- self.hs, room_id, f"@user:{self.OTHER_SERVER_NAME}", "join"
+ self.hs, room_id, remote_server_user_id, "join"
)
)
@@ -247,9 +248,15 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# create more than is 5 which corresponds to the number of backward
# extremities we slice off in `_maybe_backfill_inner`
federation_event_handler = self.hs.get_federation_event_handler()
+ auth_events = [
+ ev
+ for ev in current_state
+ if (ev.type, ev.state_key)
+ in {("m.room.create", ""), ("m.room.member", remote_server_user_id)}
+ ]
for _ in range(0, 8):
event = make_event_from_dict(
- self.add_hashes_and_signatures(
+ self.add_hashes_and_signatures_from_other_server(
{
"origin_server_ts": 1,
"type": "m.room.message",
@@ -258,15 +265,14 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
"body": "message connected to fake event",
},
"room_id": room_id,
- "sender": f"@user:{self.OTHER_SERVER_NAME}",
+ "sender": remote_server_user_id,
"prev_events": [
ev1.event_id,
# We're creating an backward extremity each time thanks
# to this fake event
generate_fake_event_id(),
],
- # lazy: *everything* is an auth event
- "auth_events": [ev.event_id for ev in current_state],
+ "auth_events": [ev.event_id for ev in auth_events],
"depth": ev1.depth + 1,
},
room_version,
@@ -276,13 +282,21 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# we poke this directly into _process_received_pdu, to avoid the
# federation handler wanting to backfill the fake event.
+ state_handler = self.hs.get_state_handler()
+ context = self.get_success(
+ state_handler.compute_event_context(
+ event,
+ state_ids_before_event={
+ (e.type, e.state_key): e.event_id for e in current_state
+ },
+ partial_state=False,
+ )
+ )
self.get_success(
federation_event_handler._process_received_pdu(
self.OTHER_SERVER_NAME,
event,
- state_ids={
- (e.type, e.state_key): e.event_id for e in current_state
- },
+ context,
)
)
@@ -308,142 +322,6 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
)
self.get_success(d)
- def test_backfill_floating_outlier_membership_auth(self) -> None:
- """
- As the local homeserver, check that we can properly process a federated
- event from the OTHER_SERVER with auth_events that include a floating
- membership event from the OTHER_SERVER.
-
- Regression test, see #10439.
- """
- OTHER_SERVER = "otherserver"
- OTHER_USER = "@otheruser:" + OTHER_SERVER
-
- # create the room
- user_id = self.register_user("kermit", "test")
- tok = self.login("kermit", "test")
- room_id = self.helper.create_room_as(
- room_creator=user_id,
- is_public=True,
- tok=tok,
- extra_content={
- "preset": "public_chat",
- },
- )
- room_version = self.get_success(self.store.get_room_version(room_id))
-
- prev_event_ids = self.get_success(self.store.get_prev_events_for_room(room_id))
- (
- most_recent_prev_event_id,
- most_recent_prev_event_depth,
- ) = self.get_success(self.store.get_max_depth_of(prev_event_ids))
- # mapping from (type, state_key) -> state_event_id
- assert most_recent_prev_event_id is not None
- prev_state_map = self.get_success(
- self.state_storage_controller.get_state_ids_for_event(
- most_recent_prev_event_id
- )
- )
- # List of state event ID's
- prev_state_ids = list(prev_state_map.values())
- auth_event_ids = prev_state_ids
- auth_events = list(
- self.get_success(self.store.get_events(auth_event_ids)).values()
- )
-
- # build a floating outlier member state event
- fake_prev_event_id = "$" + random_string(43)
- member_event_dict = {
- "type": EventTypes.Member,
- "content": {
- "membership": "join",
- },
- "state_key": OTHER_USER,
- "room_id": room_id,
- "sender": OTHER_USER,
- "depth": most_recent_prev_event_depth,
- "prev_events": [fake_prev_event_id],
- "origin_server_ts": self.clock.time_msec(),
- "signatures": {OTHER_SERVER: {"ed25519:key_version": "SomeSignatureHere"}},
- }
- builder = self.hs.get_event_builder_factory().for_room_version(
- room_version, member_event_dict
- )
- member_event = self.get_success(
- builder.build(
- prev_event_ids=member_event_dict["prev_events"],
- auth_event_ids=self._event_auth_handler.compute_auth_events(
- builder,
- prev_state_map,
- for_verification=False,
- ),
- depth=member_event_dict["depth"],
- )
- )
- # Override the signature added from "test" homeserver that we created the event with
- member_event.signatures = member_event_dict["signatures"]
-
- # Add the new member_event to the StateMap
- updated_state_map = dict(prev_state_map)
- updated_state_map[
- (member_event.type, member_event.state_key)
- ] = member_event.event_id
- auth_events.append(member_event)
-
- # build and send an event authed based on the member event
- message_event_dict = {
- "type": EventTypes.Message,
- "content": {},
- "room_id": room_id,
- "sender": OTHER_USER,
- "depth": most_recent_prev_event_depth,
- "prev_events": prev_event_ids.copy(),
- "origin_server_ts": self.clock.time_msec(),
- "signatures": {OTHER_SERVER: {"ed25519:key_version": "SomeSignatureHere"}},
- }
- builder = self.hs.get_event_builder_factory().for_room_version(
- room_version, message_event_dict
- )
- message_event = self.get_success(
- builder.build(
- prev_event_ids=message_event_dict["prev_events"],
- auth_event_ids=self._event_auth_handler.compute_auth_events(
- builder,
- updated_state_map,
- for_verification=False,
- ),
- depth=message_event_dict["depth"],
- )
- )
- # Override the signature added from "test" homeserver that we created the event with
- message_event.signatures = message_event_dict["signatures"]
-
- # Stub the /event_auth response from the OTHER_SERVER
- async def get_event_auth(
- destination: str, room_id: str, event_id: str
- ) -> List[EventBase]:
- return [
- event_from_pdu_json(ae.get_pdu_json(), room_version=room_version)
- for ae in auth_events
- ]
-
- self.handler.federation_client.get_event_auth = get_event_auth # type: ignore[assignment]
-
- with LoggingContext("receive_pdu"):
- # Fake the OTHER_SERVER federating the message event over to our local homeserver
- d = run_in_background(
- self.hs.get_federation_event_handler().on_receive_pdu,
- OTHER_SERVER,
- message_event,
- )
- self.get_success(d)
-
- # Now try and get the events on our local homeserver
- stored_event = self.get_success(
- self.store.get_event(message_event.event_id, allow_none=True)
- )
- self.assertTrue(stored_event is not None)
-
@unittest.override_config(
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
)
@@ -580,3 +458,121 @@ class EventFromPduTestCase(TestCase):
},
RoomVersions.V6,
)
+
+
+class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
+ def test_failed_partial_join_is_clean(self) -> None:
+ """
+ Tests that, when failing to partial-join a room, we don't get stuck with
+ a partial-state flag on a room.
+ """
+
+ fed_handler = self.hs.get_federation_handler()
+ fed_client = fed_handler.federation_client
+
+ room_id = "!room:example.com"
+ membership_event = make_event_from_dict(
+ {
+ "room_id": room_id,
+ "type": "m.room.member",
+ "sender": "@alice:test",
+ "state_key": "@alice:test",
+ "content": {"membership": "join"},
+ },
+ RoomVersions.V10,
+ )
+
+ mock_make_membership_event = Mock(
+ return_value=make_awaitable(
+ (
+ "example.com",
+ membership_event,
+ RoomVersions.V10,
+ )
+ )
+ )
+
+ EVENT_CREATE = make_event_from_dict(
+ {
+ "room_id": room_id,
+ "type": "m.room.create",
+ "sender": "@kristina:example.com",
+ "state_key": "",
+ "depth": 0,
+ "content": {"creator": "@kristina:example.com", "room_version": "10"},
+ "auth_events": [],
+ "origin_server_ts": 1,
+ },
+ room_version=RoomVersions.V10,
+ )
+ EVENT_CREATOR_MEMBERSHIP = make_event_from_dict(
+ {
+ "room_id": room_id,
+ "type": "m.room.member",
+ "sender": "@kristina:example.com",
+ "state_key": "@kristina:example.com",
+ "content": {"membership": "join"},
+ "depth": 1,
+ "prev_events": [EVENT_CREATE.event_id],
+ "auth_events": [EVENT_CREATE.event_id],
+ "origin_server_ts": 1,
+ },
+ room_version=RoomVersions.V10,
+ )
+ EVENT_INVITATION_MEMBERSHIP = make_event_from_dict(
+ {
+ "room_id": room_id,
+ "type": "m.room.member",
+ "sender": "@kristina:example.com",
+ "state_key": "@alice:test",
+ "content": {"membership": "invite"},
+ "depth": 2,
+ "prev_events": [EVENT_CREATOR_MEMBERSHIP.event_id],
+ "auth_events": [
+ EVENT_CREATE.event_id,
+ EVENT_CREATOR_MEMBERSHIP.event_id,
+ ],
+ "origin_server_ts": 1,
+ },
+ room_version=RoomVersions.V10,
+ )
+ mock_send_join = Mock(
+ return_value=make_awaitable(
+ SendJoinResult(
+ membership_event,
+ "example.com",
+ state=[
+ EVENT_CREATE,
+ EVENT_CREATOR_MEMBERSHIP,
+ EVENT_INVITATION_MEMBERSHIP,
+ ],
+ auth_chain=[
+ EVENT_CREATE,
+ EVENT_CREATOR_MEMBERSHIP,
+ EVENT_INVITATION_MEMBERSHIP,
+ ],
+ partial_state=True,
+ servers_in_room=["example.com"],
+ )
+ )
+ )
+
+ with patch.object(
+ fed_client, "make_membership_event", mock_make_membership_event
+ ), patch.object(fed_client, "send_join", mock_send_join):
+ # Join and check that our join event is rejected
+ # (The join event is rejected because it doesn't have any signatures)
+ join_exc = self.get_failure(
+ fed_handler.do_invite_join(["example.com"], room_id, "@alice:test", {}),
+ SynapseError,
+ )
+ self.assertIn("Join event was rejected", str(join_exc))
+
+ store = self.hs.get_datastores().main
+
+ # Check that we don't have a left-over partial_state entry.
+ self.assertFalse(
+ self.get_success(store.is_partial_state_room(room_id)),
+ f"Stale partial-stated room flag left over for {room_id} after a"
+ f" failed do_invite_join!",
+ )
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 1a36c25c41..51c8dd6498 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -98,14 +98,13 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
auth_event_ids = [
initial_state_map[("m.room.create", "")],
initial_state_map[("m.room.power_levels", "")],
- initial_state_map[("m.room.join_rules", "")],
member_event.event_id,
]
# mock up a load of state events which we are missing
state_events = [
make_event_from_dict(
- self.add_hashes_and_signatures(
+ self.add_hashes_and_signatures_from_other_server(
{
"type": "test_state_type",
"state_key": f"state_{i}",
@@ -132,7 +131,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
# Depending on the test, we either persist this upfront (as an outlier),
# or let the server request it.
prev_event = make_event_from_dict(
- self.add_hashes_and_signatures(
+ self.add_hashes_and_signatures_from_other_server(
{
"type": "test_regular_type",
"room_id": room_id,
@@ -166,7 +165,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
# mock up a regular event to pass into _process_pulled_event
pulled_event = make_event_from_dict(
- self.add_hashes_and_signatures(
+ self.add_hashes_and_signatures_from_other_server(
{
"type": "test_regular_type",
"room_id": room_id,
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 44da96c792..986b50ce0c 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -314,4 +314,4 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST", path, content={}, access_token=self.access_token
)
- self.assertEqual(int(channel.result["code"]), 403)
+ self.assertEqual(channel.code, 403)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 1231aed944..e6cd3af7b7 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -25,7 +25,7 @@ from synapse.handlers.sso import MappingException
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID
from synapse.util import Clock
-from synapse.util.macaroons import get_value_from_macaroon
+from synapse.util.macaroons import OidcSessionData, get_value_from_macaroon
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
@@ -1227,7 +1227,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
) -> str:
from synapse.handlers.oidc import OidcSessionData
- return self.handler._token_generator.generate_oidc_session_token(
+ return self.handler._macaroon_generator.generate_oidc_session_token(
state=state,
session_data=OidcSessionData(
idp_id="oidc",
@@ -1251,7 +1251,6 @@ async def _make_callback_with_userinfo(
userinfo: the OIDC userinfo dict
client_redirect_url: the URL to redirect to on success.
"""
- from synapse.handlers.oidc import OidcSessionData
handler = hs.get_oidc_handler()
provider = handler._providers["oidc"]
@@ -1260,7 +1259,7 @@ async def _make_callback_with_userinfo(
provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
state = "state"
- session = handler._token_generator.generate_oidc_session_token(
+ session = handler._macaroon_generator.generate_oidc_session_token(
state=state,
session_data=OidcSessionData(
idp_id="oidc",
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 82b3bb3b73..75934b1707 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -14,13 +14,13 @@
"""Tests for the password_auth_provider interface"""
+from http import HTTPStatus
from typing import Any, Type, Union
from unittest.mock import Mock
import synapse
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
-from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.module_api import ModuleApi
from synapse.rest.client import account, devices, login, logout, register
from synapse.types import JsonDict, UserID
@@ -166,16 +166,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
super().setUp()
- def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver()
- # Load the modules into the homeserver
- module_api = hs.get_module_api()
- for module, config in hs.config.modules.loaded_modules:
- module(config=config, api=module_api)
- load_legacy_password_auth_providers(hs)
-
- return hs
-
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_progiver_login_legacy(self):
self.password_only_auth_provider_login_test_body()
@@ -188,14 +178,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(True)
channel = self._send_password_login("u", "p")
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@u:test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
mock_password_provider.reset_mock()
# login with mxid should work too
channel = self._send_password_login("@u:bz", "p")
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@u:bz", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
mock_password_provider.reset_mock()
@@ -204,7 +194,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# in these cases, but at least we can guard against the API changing
# unexpectedly
channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with(
"@ USER🙂NAME :test", " pASS😢word "
@@ -258,10 +248,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._send_password_login("u", "p")
- self.assertEqual(channel.code, 403, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
channel = self._send_password_login("localuser", "localpass")
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@localuser:test", channel.json_body["user_id"])
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
@@ -382,7 +372,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("u", "p")
- self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_password.assert_not_called()
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
@@ -406,14 +396,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login with missing param should be rejected
channel = self._send_login("test.login_type", "u")
- self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None)
)
channel = self._send_login("test.login_type", "u", test_field="y")
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"}
@@ -427,7 +417,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
("@ MALFORMED! :bz", None)
)
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
" USER🙂NAME ", "test.login_type", {"test_field": " abc "}
@@ -510,7 +500,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
("@user:bz", callback)
)
channel = self._send_login("test.login_type", "u", test_field="y")
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"}
@@ -549,7 +539,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
- self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()
@override_config(
@@ -584,7 +574,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
- self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()
@override_config(
@@ -615,7 +605,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
- self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_password.assert_not_called()
@@ -646,13 +636,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
("@localuser:test", None)
)
channel = self._send_login("test.login_type", "localuser", test_field="")
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
tok1 = channel.json_body["access_token"]
channel = self._send_login(
"test.login_type", "localuser", test_field="", device_id="dev2"
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
# make the initial request which returns a 401
channel = self._delete_device(tok1, "dev2")
@@ -721,7 +711,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# password login shouldn't work and should be rejected with a 400
# ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
- self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
def test_on_logged_out(self):
"""Tests that the on_logged_out callback is called when the user logs out."""
@@ -884,7 +874,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
},
access_token=tok,
)
- self.assertEqual(channel.code, 403, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
self.assertEqual(
channel.json_body["errcode"],
Codes.THREEPID_DENIED,
@@ -906,7 +896,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
},
access_token=tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertIn("sid", channel.json_body)
m.assert_called_once_with("email", "bar@test.com", registration)
@@ -949,12 +939,12 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"register",
{"auth": {"session": session, "type": LoginType.DUMMY}},
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
return channel.json_body
def _get_login_flows(self) -> JsonDict:
channel = self.make_request("GET", "/_matrix/client/r0/login")
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
return channel.json_body["flows"]
def _send_password_login(self, user: str, password: str) -> FakeChannel:
diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py
index a95868b5c0..b55238650c 100644
--- a/tests/handlers/test_receipts.py
+++ b/tests/handlers/test_receipts.py
@@ -25,7 +25,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.event_source = hs.get_event_sources().sources.receipt
- def test_filters_out_private_receipt(self):
+ def test_filters_out_private_receipt(self) -> None:
self._test_filters_private(
[
{
@@ -45,7 +45,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
[],
)
- def test_filters_out_private_receipt_and_ignores_rest(self):
+ def test_filters_out_private_receipt_and_ignores_rest(self) -> None:
self._test_filters_private(
[
{
@@ -84,7 +84,9 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- def test_filters_out_event_with_only_private_receipts_and_ignores_the_rest(self):
+ def test_filters_out_event_with_only_private_receipts_and_ignores_the_rest(
+ self,
+ ) -> None:
self._test_filters_private(
[
{
@@ -125,7 +127,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- def test_handles_empty_event(self):
+ def test_handles_empty_event(self) -> None:
self._test_filters_private(
[
{
@@ -160,7 +162,9 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- def test_filters_out_receipt_event_with_only_private_receipt_and_ignores_rest(self):
+ def test_filters_out_receipt_event_with_only_private_receipt_and_ignores_rest(
+ self,
+ ) -> None:
self._test_filters_private(
[
{
@@ -207,7 +211,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- def test_handles_string_data(self):
+ def test_handles_string_data(self) -> None:
"""
Tests that an invalid shape for read-receipts is handled.
Context: https://github.com/matrix-org/synapse/issues/10603
@@ -242,7 +246,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- def test_leaves_our_private_and_their_public(self):
+ def test_leaves_our_private_and_their_public(self) -> None:
self._test_filters_private(
[
{
@@ -296,7 +300,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- def test_we_do_not_mutate(self):
+ def test_we_do_not_mutate(self) -> None:
"""Ensure the input values are not modified."""
events = [
{
@@ -320,7 +324,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
def _test_filters_private(
self, events: List[JsonDict], expected_output: List[JsonDict]
- ):
+ ) -> None:
"""Tests that the _filter_out_private returns the expected output"""
filtered_events = self.event_source.filter_out_private_receipts(
events, "@me:server.org"
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index b6ba19c739..86b3d51975 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -22,7 +22,6 @@ from synapse.api.errors import (
ResourceLimitError,
SynapseError,
)
-from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, RoomID, UserID, create_requester
@@ -144,12 +143,6 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
config=hs_config, federation_client=self.mock_federation_client
)
- load_legacy_spam_checkers(hs)
-
- module_api = hs.get_module_api()
- for module, config in hs.config.modules.loaded_modules:
- module(config=config, api=module_api)
-
return hs
def prepare(self, reactor, clock, hs):
@@ -699,7 +692,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"""
if localpart is None:
raise SynapseError(400, "Request must include user id")
- await self.hs.get_auth().check_auth_blocking()
+ await self.hs.get_auth_blocking().check_auth_blocking()
need_register = True
try:
diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
new file mode 100644
index 0000000000..6bbfd5dc84
--- /dev/null
+++ b/tests/handlers/test_room_member.py
@@ -0,0 +1,378 @@
+from unittest.mock import Mock, patch
+
+from twisted.test.proto_helpers import MemoryReactor
+
+import synapse.rest.admin
+import synapse.rest.client.login
+import synapse.rest.client.room
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.errors import LimitExceededError, SynapseError
+from synapse.crypto.event_signing import add_hashes_and_signatures
+from synapse.events import FrozenEventV3
+from synapse.federation.federation_client import SendJoinResult
+from synapse.server import HomeServer
+from synapse.types import UserID, create_requester
+from synapse.util import Clock
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import make_request
+from tests.test_utils import make_awaitable
+from tests.unittest import (
+ FederatingHomeserverTestCase,
+ HomeserverTestCase,
+ override_config,
+)
+
+
+class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.client.login.register_servlets,
+ synapse.rest.client.room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.handler = hs.get_room_member_handler()
+
+ # Create three users.
+ self.alice = self.register_user("alice", "pass")
+ self.alice_token = self.login("alice", "pass")
+ self.bob = self.register_user("bob", "pass")
+ self.bob_token = self.login("bob", "pass")
+ self.chris = self.register_user("chris", "pass")
+ self.chris_token = self.login("chris", "pass")
+
+ # Create a room on this homeserver. Note that this counts as a join: it
+ # contributes to the rate limter's count of actions
+ self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
+
+ self.intially_unjoined_room_id = f"!example:{self.OTHER_SERVER_NAME}"
+
+ @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
+ def test_local_user_local_joins_contribute_to_limit_and_are_limited(self) -> None:
+ # The rate limiter has accumulated one token from Alice's join after the create
+ # event.
+ # Try joining the room as Bob.
+ self.get_success(
+ self.handler.update_membership(
+ requester=create_requester(self.bob),
+ target=UserID.from_string(self.bob),
+ room_id=self.room_id,
+ action=Membership.JOIN,
+ )
+ )
+
+ # The rate limiter bucket is full. A second join should be denied.
+ self.get_failure(
+ self.handler.update_membership(
+ requester=create_requester(self.chris),
+ target=UserID.from_string(self.chris),
+ room_id=self.room_id,
+ action=Membership.JOIN,
+ ),
+ LimitExceededError,
+ )
+
+ @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
+ def test_local_user_profile_edits_dont_contribute_to_limit(self) -> None:
+ # The rate limiter has accumulated one token from Alice's join after the create
+ # event. Alice should still be able to change her displayname.
+ self.get_success(
+ self.handler.update_membership(
+ requester=create_requester(self.alice),
+ target=UserID.from_string(self.alice),
+ room_id=self.room_id,
+ action=Membership.JOIN,
+ content={"displayname": "Alice Cooper"},
+ )
+ )
+
+ # Still room in the limiter bucket. Chris's join should be accepted.
+ self.get_success(
+ self.handler.update_membership(
+ requester=create_requester(self.chris),
+ target=UserID.from_string(self.chris),
+ room_id=self.room_id,
+ action=Membership.JOIN,
+ )
+ )
+
+ @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 1}})
+ def test_remote_joins_contribute_to_rate_limit(self) -> None:
+ # Join once, to fill the rate limiter bucket.
+ #
+ # To do this we have to mock the responses from the remote homeserver.
+ # We also patch out a bunch of event checks on our end. All we're really
+ # trying to check here is that remote joins will bump the rate limter when
+ # they are persisted.
+ create_event_source = {
+ "auth_events": [],
+ "content": {
+ "creator": f"@creator:{self.OTHER_SERVER_NAME}",
+ "room_version": self.hs.config.server.default_room_version.identifier,
+ },
+ "depth": 0,
+ "origin_server_ts": 0,
+ "prev_events": [],
+ "room_id": self.intially_unjoined_room_id,
+ "sender": f"@creator:{self.OTHER_SERVER_NAME}",
+ "state_key": "",
+ "type": EventTypes.Create,
+ }
+ self.add_hashes_and_signatures_from_other_server(
+ create_event_source,
+ self.hs.config.server.default_room_version,
+ )
+ create_event = FrozenEventV3(
+ create_event_source,
+ self.hs.config.server.default_room_version,
+ {},
+ None,
+ )
+
+ join_event_source = {
+ "auth_events": [create_event.event_id],
+ "content": {"membership": "join"},
+ "depth": 1,
+ "origin_server_ts": 100,
+ "prev_events": [create_event.event_id],
+ "sender": self.bob,
+ "state_key": self.bob,
+ "room_id": self.intially_unjoined_room_id,
+ "type": EventTypes.Member,
+ }
+ add_hashes_and_signatures(
+ self.hs.config.server.default_room_version,
+ join_event_source,
+ self.hs.hostname,
+ self.hs.signing_key,
+ )
+ join_event = FrozenEventV3(
+ join_event_source,
+ self.hs.config.server.default_room_version,
+ {},
+ None,
+ )
+
+ mock_make_membership_event = Mock(
+ return_value=make_awaitable(
+ (
+ self.OTHER_SERVER_NAME,
+ join_event,
+ self.hs.config.server.default_room_version,
+ )
+ )
+ )
+ mock_send_join = Mock(
+ return_value=make_awaitable(
+ SendJoinResult(
+ join_event,
+ self.OTHER_SERVER_NAME,
+ state=[create_event],
+ auth_chain=[create_event],
+ partial_state=False,
+ servers_in_room=[],
+ )
+ )
+ )
+
+ with patch.object(
+ self.handler.federation_handler.federation_client,
+ "make_membership_event",
+ mock_make_membership_event,
+ ), patch.object(
+ self.handler.federation_handler.federation_client,
+ "send_join",
+ mock_send_join,
+ ), patch(
+ "synapse.event_auth._is_membership_change_allowed",
+ return_value=None,
+ ), patch(
+ "synapse.handlers.federation_event.check_state_dependent_auth_rules",
+ return_value=None,
+ ):
+ self.get_success(
+ self.handler.update_membership(
+ requester=create_requester(self.bob),
+ target=UserID.from_string(self.bob),
+ room_id=self.intially_unjoined_room_id,
+ action=Membership.JOIN,
+ remote_room_hosts=[self.OTHER_SERVER_NAME],
+ )
+ )
+
+ # Try to join as Chris. Should get denied.
+ self.get_failure(
+ self.handler.update_membership(
+ requester=create_requester(self.chris),
+ target=UserID.from_string(self.chris),
+ room_id=self.intially_unjoined_room_id,
+ action=Membership.JOIN,
+ remote_room_hosts=[self.OTHER_SERVER_NAME],
+ ),
+ LimitExceededError,
+ )
+
+ # TODO: test that remote joins to a room are rate limited.
+ # Could do this by setting the burst count to 1, then:
+ # - remote-joining a room
+ # - immediately leaving
+ # - trying to remote-join again.
+
+
+class TestReplicatedJoinsLimitedByPerRoomRateLimiter(BaseMultiWorkerStreamTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.client.login.register_servlets,
+ synapse.rest.client.room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.handler = hs.get_room_member_handler()
+
+ # Create three users.
+ self.alice = self.register_user("alice", "pass")
+ self.alice_token = self.login("alice", "pass")
+ self.bob = self.register_user("bob", "pass")
+ self.bob_token = self.login("bob", "pass")
+ self.chris = self.register_user("chris", "pass")
+ self.chris_token = self.login("chris", "pass")
+
+ # Create a room on this homeserver.
+ # Note that this counts as a
+ self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
+ self.intially_unjoined_room_id = "!example:otherhs"
+
+ @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
+ def test_local_users_joining_on_another_worker_contribute_to_rate_limit(
+ self,
+ ) -> None:
+ # The rate limiter has accumulated one token from Alice's join after the create
+ # event.
+ self.replicate()
+
+ # Spawn another worker and have bob join via it.
+ worker_app = self.make_worker_hs(
+ "synapse.app.generic_worker", extra_config={"worker_name": "other worker"}
+ )
+ worker_site = self._hs_to_site[worker_app]
+ channel = make_request(
+ self.reactor,
+ worker_site,
+ "POST",
+ f"/_matrix/client/v3/rooms/{self.room_id}/join",
+ access_token=self.bob_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # wait for join to arrive over replication
+ self.replicate()
+
+ # Try to join as Chris on the worker. Should get denied because Alice
+ # and Bob have both joined the room.
+ self.get_failure(
+ worker_app.get_room_member_handler().update_membership(
+ requester=create_requester(self.chris),
+ target=UserID.from_string(self.chris),
+ room_id=self.room_id,
+ action=Membership.JOIN,
+ ),
+ LimitExceededError,
+ )
+
+ # Try to join as Chris on the original worker. Should get denied because Alice
+ # and Bob have both joined the room.
+ self.get_failure(
+ self.handler.update_membership(
+ requester=create_requester(self.chris),
+ target=UserID.from_string(self.chris),
+ room_id=self.room_id,
+ action=Membership.JOIN,
+ ),
+ LimitExceededError,
+ )
+
+
+class RoomMemberMasterHandlerTestCase(HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.client.login.register_servlets,
+ synapse.rest.client.room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.handler = hs.get_room_member_handler()
+ self.store = hs.get_datastores().main
+
+ # Create two users.
+ self.alice = self.register_user("alice", "pass")
+ self.alice_ID = UserID.from_string(self.alice)
+ self.alice_token = self.login("alice", "pass")
+ self.bob = self.register_user("bob", "pass")
+ self.bob_ID = UserID.from_string(self.bob)
+ self.bob_token = self.login("bob", "pass")
+
+ # Create a room on this homeserver.
+ self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
+
+ def test_leave_and_forget(self) -> None:
+ """Tests that forget a room is successfully. The test is performed with two users,
+ as forgetting by the last user respectively after all users had left the
+ is a special edge case."""
+ self.helper.join(self.room_id, user=self.bob, tok=self.bob_token)
+
+ # alice is not the last room member that leaves and forgets the room
+ self.helper.leave(self.room_id, user=self.alice, tok=self.alice_token)
+ self.get_success(self.handler.forget(self.alice_ID, self.room_id))
+ self.assertTrue(
+ self.get_success(self.store.did_forget(self.alice, self.room_id))
+ )
+
+ # the server has not forgotten the room
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room_id))
+ )
+
+ def test_leave_and_forget_last_user(self) -> None:
+ """Tests that forget a room is successfully when the last user has left the room."""
+
+ # alice is the last room member that leaves and forgets the room
+ self.helper.leave(self.room_id, user=self.alice, tok=self.alice_token)
+ self.get_success(self.handler.forget(self.alice_ID, self.room_id))
+ self.assertTrue(
+ self.get_success(self.store.did_forget(self.alice, self.room_id))
+ )
+
+ # the server has forgotten the room
+ self.assertTrue(
+ self.get_success(self.store.is_locally_forgotten_room(self.room_id))
+ )
+
+ def test_forget_when_not_left(self) -> None:
+ """Tests that a user cannot not forgets a room that has not left."""
+ self.get_failure(self.handler.forget(self.alice_ID, self.room_id), SynapseError)
+
+ def test_rejoin_forgotten_by_user(self) -> None:
+ """Test that a user that has forgotten a room can do a re-join.
+ The room was not forgotten from the local server.
+ One local user is still member of the room."""
+ self.helper.join(self.room_id, user=self.bob, tok=self.bob_token)
+
+ self.helper.leave(self.room_id, user=self.alice, tok=self.alice_token)
+ self.get_success(self.handler.forget(self.alice_ID, self.room_id))
+ self.assertTrue(
+ self.get_success(self.store.did_forget(self.alice, self.room_id))
+ )
+
+ # the server has not forgotten the room
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room_id))
+ )
+
+ self.helper.join(self.room_id, user=self.alice, tok=self.alice_token)
+ # TODO: A join to a room does not invalidate the forgotten cache
+ # see https://github.com/matrix-org/synapse/issues/13262
+ self.store.did_forget.invalidate_all()
+ self.assertFalse(
+ self.get_success(self.store.did_forget(self.alice, self.room_id))
+ )
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index 0546655690..aa650756e4 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -178,7 +178,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
result_room_ids.append(result_room["room_id"])
result_children_ids.append(
[
- (cs["room_id"], cs["state_key"])
+ (result_room["room_id"], cs["state_key"])
for cs in result_room["children_state"]
]
)
diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py
index 6f77b1237c..da4bf8b582 100644
--- a/tests/handlers/test_send_email.py
+++ b/tests/handlers/test_send_email.py
@@ -23,7 +23,7 @@ from twisted.internet.defer import ensureDeferred
from twisted.mail import interfaces, smtp
from tests.server import FakeTransport
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
@implementer(interfaces.IMessageDelivery)
@@ -110,3 +110,58 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
user, msg = message_delivery.messages.pop()
self.assertEqual(str(user), "foo@bar.com")
self.assertIn(b"Subject: test subject", msg)
+
+ @override_config(
+ {
+ "email": {
+ "notif_from": "noreply@test",
+ "force_tls": True,
+ },
+ }
+ )
+ def test_send_email_force_tls(self):
+ """Happy-path test that we can send email to an Implicit TLS server."""
+ h = self.hs.get_send_email_handler()
+ d = ensureDeferred(
+ h.send_email(
+ "foo@bar.com", "test subject", "Tests", "HTML content", "Text content"
+ )
+ )
+ # there should be an attempt to connect to localhost:465
+ self.assertEqual(len(self.reactor.sslClients), 1)
+ (
+ host,
+ port,
+ client_factory,
+ contextFactory,
+ _timeout,
+ _bindAddress,
+ ) = self.reactor.sslClients[0]
+ self.assertEqual(host, "localhost")
+ self.assertEqual(port, 465)
+
+ # wire it up to an SMTP server
+ message_delivery = _DummyMessageDelivery()
+ server_protocol = smtp.ESMTP()
+ server_protocol.delivery = message_delivery
+ # make sure that the server uses the test reactor to set timeouts
+ server_protocol.callLater = self.reactor.callLater # type: ignore[assignment]
+
+ client_protocol = client_factory.buildProtocol(None)
+ client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor))
+ server_protocol.makeConnection(
+ FakeTransport(
+ client_protocol,
+ self.reactor,
+ peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
+ )
+ )
+
+ # the message should now get delivered
+ self.get_success(d, by=0.1)
+
+ # check it arrived
+ self.assertEqual(len(message_delivery.messages), 1)
+ user, msg = message_delivery.messages.pop()
+ self.assertEqual(str(user), "foo@bar.com")
+ self.assertIn(b"Subject: test subject", msg)
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index ecd78fa369..05f9ec3c51 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -46,16 +46,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.get_success(
self.store.db_pool.simple_insert(
"background_updates",
- {"update_name": "populate_stats_prepare", "progress_json": "{}"},
- )
- )
- self.get_success(
- self.store.db_pool.simple_insert(
- "background_updates",
{
"update_name": "populate_stats_process_rooms",
"progress_json": "{}",
- "depends_on": "populate_stats_prepare",
},
)
)
@@ -69,16 +62,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
},
)
)
- self.get_success(
- self.store.db_pool.simple_insert(
- "background_updates",
- {
- "update_name": "populate_stats_cleanup",
- "progress_json": "{}",
- "depends_on": "populate_stats_process_users",
- },
- )
- )
async def get_all_room_state(self):
return await self.store.db_pool.simple_select_list(
@@ -533,7 +516,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
{
"update_name": "populate_stats_process_rooms",
"progress_json": "{}",
- "depends_on": "populate_stats_prepare",
},
)
)
@@ -547,16 +529,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
},
)
)
- self.get_success(
- self.store.db_pool.simple_insert(
- "background_updates",
- {
- "update_name": "populate_stats_cleanup",
- "progress_json": "{}",
- "depends_on": "populate_stats_process_users",
- },
- )
- )
self.wait_for_background_updates()
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index db3302a4c7..e3f38fbcc5 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -45,7 +45,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
- self.auth_blocking = self.hs.get_auth()._auth_blocking
+ self.auth_blocking = self.hs.get_auth_blocking()
def test_wait_for_sync_for_user_auth_blocking(self):
user_id1 = "@user1:test"
@@ -159,7 +159,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Blow away caches (supported room versions can only change due to a restart).
self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
- self.store._get_event_cache.clear()
+ self.get_success(self.store._get_event_cache.clear())
self.store._event_ref.clear()
# The rooms should be excluded from the sync response.
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 7af1333126..8adba29d7f 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -25,7 +25,7 @@ from synapse.api.constants import EduTypes
from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer
from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID, create_requester
+from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import Clock
from tests import unittest
@@ -117,8 +117,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = []
- async def check_user_in_room(room_id: str, user_id: str) -> None:
- if user_id not in [u.to_string() for u in self.room_members]:
+ async def check_user_in_room(room_id: str, requester: Requester) -> None:
+ if requester.user.to_string() not in [
+ u.to_string() for u in self.room_members
+ ]:
raise AuthError(401, "User is not in the room")
return None
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index b9f1a381aa..5726e60cee 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -12,89 +12,542 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
-from typing import Any, Callable, Optional, Union
+import inspect
+import itertools
+import logging
+from typing import (
+ Any,
+ Callable,
+ ContextManager,
+ Dict,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+ Union,
+)
from unittest import mock
+from unittest.mock import Mock
+from twisted.internet.defer import Deferred
from twisted.internet.error import ConnectionDone
+from twisted.python.failure import Failure
+from twisted.test.proto_helpers import MemoryReactorClock
+from twisted.web.server import Site
from synapse.http.server import (
HTTP_STATUS_REQUEST_CANCELLED,
respond_with_html_bytes,
respond_with_json,
)
+from synapse.http.site import SynapseRequest
+from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.types import JsonDict
-from tests import unittest
-from tests.server import FakeChannel, ThreadedMemoryReactorClock
+from tests.server import FakeChannel, make_request
+from tests.unittest import logcontext_clean
+logger = logging.getLogger(__name__)
-class EndpointCancellationTestHelperMixin(unittest.TestCase):
- """Provides helper methods for testing cancellation of endpoints."""
- def _test_disconnect(
- self,
- reactor: ThreadedMemoryReactorClock,
- channel: FakeChannel,
- expect_cancellation: bool,
- expected_body: Union[bytes, JsonDict],
- expected_code: Optional[int] = None,
- ) -> None:
- """Disconnects an in-flight request and checks the response.
+T = TypeVar("T")
- Args:
- reactor: The twisted reactor running the request handler.
- channel: The `FakeChannel` for the request.
- expect_cancellation: `True` if request processing is expected to be
- cancelled, `False` if the request should run to completion.
- expected_body: The expected response for the request.
- expected_code: The expected status code for the request. Defaults to `200`
- or `499` depending on `expect_cancellation`.
- """
- # Determine the expected status code.
- if expected_code is None:
- if expect_cancellation:
- expected_code = HTTP_STATUS_REQUEST_CANCELLED
- else:
- expected_code = HTTPStatus.OK
-
- request = channel.request
- self.assertFalse(
- channel.is_finished(),
+
+def test_disconnect(
+ reactor: MemoryReactorClock,
+ channel: FakeChannel,
+ expect_cancellation: bool,
+ expected_body: Union[bytes, JsonDict],
+ expected_code: Optional[int] = None,
+) -> None:
+ """Disconnects an in-flight request and checks the response.
+
+ Args:
+ reactor: The twisted reactor running the request handler.
+ channel: The `FakeChannel` for the request.
+ expect_cancellation: `True` if request processing is expected to be cancelled,
+ `False` if the request should run to completion.
+ expected_body: The expected response for the request.
+ expected_code: The expected status code for the request. Defaults to `200` or
+ `499` depending on `expect_cancellation`.
+ """
+ # Determine the expected status code.
+ if expected_code is None:
+ if expect_cancellation:
+ expected_code = HTTP_STATUS_REQUEST_CANCELLED
+ else:
+ expected_code = 200
+
+ request = channel.request
+ if channel.is_finished():
+ raise AssertionError(
"Request finished before we could disconnect - "
- "was `await_result=False` passed to `make_request`?",
+ "ensure `await_result=False` is passed to `make_request`.",
)
- # We're about to disconnect the request. This also disconnects the channel, so
- # we have to rely on mocks to extract the response.
- respond_method: Callable[..., Any]
- if isinstance(expected_body, bytes):
- respond_method = respond_with_html_bytes
+ # We're about to disconnect the request. This also disconnects the channel, so we
+ # have to rely on mocks to extract the response.
+ respond_method: Callable[..., Any]
+ if isinstance(expected_body, bytes):
+ respond_method = respond_with_html_bytes
+ else:
+ respond_method = respond_with_json
+
+ with mock.patch(
+ f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
+ ) as respond_mock:
+ # Disconnect the request.
+ request.connectionLost(reason=ConnectionDone())
+
+ if expect_cancellation:
+ # An immediate cancellation is expected.
+ respond_mock.assert_called_once()
else:
- respond_method = respond_with_json
+ respond_mock.assert_not_called()
- with mock.patch(
- f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
- ) as respond_mock:
- # Disconnect the request.
- request.connectionLost(reason=ConnectionDone())
+ # The handler is expected to run to completion.
+ reactor.advance(1.0)
+ respond_mock.assert_called_once()
- if expect_cancellation:
- # An immediate cancellation is expected.
- respond_mock.assert_called_once()
- args, _kwargs = respond_mock.call_args
- code, body = args[1], args[2]
- self.assertEqual(code, expected_code)
- self.assertEqual(request.code, expected_code)
- self.assertEqual(body, expected_body)
- else:
- respond_mock.assert_not_called()
-
- # The handler is expected to run to completion.
- reactor.pump([1.0])
+ args, _kwargs = respond_mock.call_args
+ code, body = args[1], args[2]
+
+ if code != expected_code:
+ raise AssertionError(
+ f"{code} != {expected_code} : "
+ "Request did not finish with the expected status code."
+ )
+
+ if request.code != expected_code:
+ raise AssertionError(
+ f"{request.code} != {expected_code} : "
+ "Request did not finish with the expected status code."
+ )
+
+ if body != expected_body:
+ raise AssertionError(
+ f"{body!r} != {expected_body!r} : "
+ "Request did not finish with the expected status code."
+ )
+
+
+@logcontext_clean
+def make_request_with_cancellation_test(
+ test_name: str,
+ reactor: MemoryReactorClock,
+ site: Site,
+ method: str,
+ path: str,
+ content: Union[bytes, str, JsonDict] = b"",
+) -> FakeChannel:
+ """Performs a request repeatedly, disconnecting at successive `await`s, until
+ one completes.
+
+ Fails if:
+ * A logging context is lost during cancellation.
+ * A logging context get restarted after it is marked as finished, eg. if
+ a request's logging context is used by some processing started by the
+ request, but the request neglects to cancel that processing or wait for it
+ to complete.
+
+ Note that "Re-starting finished log context" errors get raised within the
+ request handling code and may or may not get caught. These errors will
+ likely manifest as a different logging context error at a later point. When
+ debugging logging context failures, setting a breakpoint in
+ `logcontext_error` can prove useful.
+ * A request gets stuck, possibly due to a previous cancellation.
+ * The request does not return a 499 when the client disconnects.
+ This implies that a `CancelledError` was swallowed somewhere.
+
+ It is up to the caller to verify that the request returns the correct data when
+ it finally runs to completion.
+
+ Note that this function can only cover a single code path and does not guarantee
+ that an endpoint is compatible with cancellation on every code path.
+ To allow inspection of the code path that is being tested, this function will
+ log the stack trace at every `await` that gets cancelled. To view these log
+ lines, `trial` can be run with the `SYNAPSE_TEST_LOG_LEVEL=INFO` environment
+ variable, which will include the log lines in `_trial_temp/test.log`.
+ Alternatively, `_log_for_request` can be modified to write to `sys.stdout`.
+
+ Args:
+ test_name: The name of the test, which will be logged.
+ reactor: The twisted reactor running the request handler.
+ site: The twisted `Site` to use to render the request.
+ method: The HTTP request method ("verb").
+ path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and
+ such).
+ content: The body of the request.
+
+ Returns:
+ The `FakeChannel` object which stores the result of the final request that
+ runs to completion.
+ """
+ # To process a request, a coroutine run is created for the async method handling
+ # the request. That method may then start other coroutine runs, wrapped in
+ # `Deferred`s.
+ #
+ # We would like to trigger a cancellation at the first `await`, re-run the
+ # request and cancel at the second `await`, and so on. By patching
+ # `Deferred.__next__`, we can intercept `await`s, track which ones we have or
+ # have not seen, and force them to block when they wouldn't have.
+
+ # The set of previously seen `await`s.
+ # Each element is a stringified stack trace.
+ seen_awaits: Set[Tuple[str, ...]] = set()
+
+ _log_for_request(
+ 0, f"Running make_request_with_cancellation_test for {test_name}..."
+ )
+
+ for request_number in itertools.count(1):
+ deferred_patch = Deferred__next__Patch(seen_awaits, request_number)
+
+ try:
+ with mock.patch(
+ "synapse.http.server.respond_with_json", wraps=respond_with_json
+ ) as respond_mock:
+ with deferred_patch.patch():
+ # Start the request.
+ channel = make_request(
+ reactor, site, method, path, content, await_result=False
+ )
+ request = channel.request
+
+ # Run the request until we see a new `await` which we have not
+ # yet cancelled at, or it completes.
+ while not respond_mock.called and not deferred_patch.new_await_seen:
+ previous_awaits_seen = deferred_patch.awaits_seen
+
+ reactor.advance(0.0)
+
+ if deferred_patch.awaits_seen == previous_awaits_seen:
+ # We didn't see any progress. Try advancing the clock.
+ reactor.advance(1.0)
+
+ if deferred_patch.awaits_seen == previous_awaits_seen:
+ # We still didn't see any progress. The request might be
+ # stuck.
+ raise AssertionError(
+ "Request appears to be stuck, possibly due to a "
+ "previous cancelled request"
+ )
+
+ if respond_mock.called:
+ # The request ran to completion and we are done with testing it.
+
+ # `respond_with_json` writes the response asynchronously, so we
+ # might have to give the reactor a kick before the channel gets
+ # the response.
+ deferred_patch.unblock_awaits()
+ channel.await_result()
+
+ return channel
+
+ # Disconnect the client and wait for the response.
+ request.connectionLost(reason=ConnectionDone())
+
+ _log_for_request(request_number, "--- disconnected ---")
+
+ # Advance the reactor just enough to get a response.
+ # We don't want to advance the reactor too far, because we can only
+ # detect re-starts of finished logging contexts after we set the
+ # finished flag below.
+ for _ in range(2):
+ # We may need to pump the reactor to allow `delay_cancellation`s to
+ # finish.
+ if not respond_mock.called:
+ reactor.advance(0.0)
+
+ # Try advancing the clock if that didn't work.
+ if not respond_mock.called:
+ reactor.advance(1.0)
+
+ # `delay_cancellation`s may be waiting for processing that we've
+ # forced to block. Try unblocking them, followed by another round of
+ # pumping the reactor.
+ if not respond_mock.called:
+ deferred_patch.unblock_awaits()
+
+ # Mark the request's logging context as finished. If it gets
+ # activated again, an `AssertionError` will be raised and bubble up
+ # through request handling code. This `AssertionError` may or may not be
+ # caught. Eventually some other code will deactivate the logging
+ # context which will raise a different `AssertionError` because
+ # resource usage won't have been correctly tracked.
+ if isinstance(request, SynapseRequest) and request.logcontext:
+ request.logcontext.finished = True
+
+ # Check that the request finished with a 499,
+ # ie. the `CancelledError` wasn't swallowed.
respond_mock.assert_called_once()
- args, _kwargs = respond_mock.call_args
- code, body = args[1], args[2]
- self.assertEqual(code, expected_code)
- self.assertEqual(request.code, expected_code)
- self.assertEqual(body, expected_body)
+
+ if request.code != HTTP_STATUS_REQUEST_CANCELLED:
+ raise AssertionError(
+ f"{request.code} != {HTTP_STATUS_REQUEST_CANCELLED} : "
+ "Cancelled request did not finish with the correct status code."
+ )
+ finally:
+ # Unblock any processing that might be shared between requests, if we
+ # haven't already done so.
+ deferred_patch.unblock_awaits()
+
+ assert False, "unreachable" # noqa: B011
+
+
+class Deferred__next__Patch:
+ """A `Deferred.__next__` patch that will intercept `await`s and force them
+ to block once it sees a new `await`.
+
+ When done with the patch, `unblock_awaits()` must be called to clean up after any
+ `await`s that were forced to block, otherwise processing shared between multiple
+ requests, such as database queries started by `@cached`, will become permanently
+ stuck.
+
+ Usage:
+ seen_awaits = set()
+ deferred_patch = Deferred__next__Patch(seen_awaits, 1)
+ try:
+ with deferred_patch.patch():
+ # do things
+ ...
+ finally:
+ deferred_patch.unblock_awaits()
+ """
+
+ def __init__(self, seen_awaits: Set[Tuple[str, ...]], request_number: int):
+ """
+ Args:
+ seen_awaits: The set of stack traces of `await`s that have been previously
+ seen. When the `Deferred.__next__` patch sees a new `await`, it will add
+ it to the set.
+ request_number: The request number to log against.
+ """
+ self._request_number = request_number
+ self._seen_awaits = seen_awaits
+
+ self._original_Deferred___next__ = Deferred.__next__
+
+ # The number of `await`s on `Deferred`s we have seen so far.
+ self.awaits_seen = 0
+
+ # Whether we have seen a new `await` not in `seen_awaits`.
+ self.new_await_seen = False
+
+ # To force `await`s on resolved `Deferred`s to block, we make up a new
+ # unresolved `Deferred` and return it out of `Deferred.__next__` /
+ # `coroutine.send()`. We have to resolve it later, in case the `await`ing
+ # coroutine is part of some shared processing, such as `@cached`.
+ self._to_unblock: Dict[Deferred, Union[object, Failure]] = {}
+
+ # The last stack we logged.
+ self._previous_stack: List[inspect.FrameInfo] = []
+
+ def patch(self) -> ContextManager[Mock]:
+ """Returns a context manager which patches `Deferred.__next__`."""
+
+ def Deferred___next__(
+ deferred: "Deferred[T]", value: object = None
+ ) -> "Deferred[T]":
+ """Intercepts `await`s on `Deferred`s and rigs them to block once we have
+ seen enough of them.
+
+ `Deferred.__next__` will normally:
+ * return `self` if the `Deferred` is unresolved, in which case
+ `coroutine.send()` will return the `Deferred`, and
+ `_defer.inlineCallbacks` will stop running the coroutine until the
+ `Deferred` is resolved.
+ * raise a `StopIteration(result)`, containing the result of the `await`.
+ * raise another exception, which will come out of the `await`.
+ """
+ self.awaits_seen += 1
+
+ stack = _get_stack(skip_frames=1)
+ stack_hash = _hash_stack(stack)
+
+ if stack_hash not in self._seen_awaits:
+ # Block at the current `await` onwards.
+ self._seen_awaits.add(stack_hash)
+ self.new_await_seen = True
+
+ if not self.new_await_seen:
+ # This `await` isn't interesting. Let it proceed normally.
+
+ # Don't log the stack. It's been seen before in a previous run.
+ self._previous_stack = stack
+
+ return self._original_Deferred___next__(deferred, value)
+
+ # We want to block at the current `await`.
+ if deferred.called and not deferred.paused:
+ # This `Deferred` already has a result.
+ # We return a new, unresolved, `Deferred` for `_inlineCallbacks` to wait
+ # on. This blocks the coroutine that did this `await`.
+ # We queue it up for unblocking later.
+ new_deferred: "Deferred[T]" = Deferred()
+ self._to_unblock[new_deferred] = deferred.result
+
+ _log_await_stack(
+ stack,
+ self._previous_stack,
+ self._request_number,
+ "force-blocked await",
+ )
+ self._previous_stack = stack
+
+ return make_deferred_yieldable(new_deferred)
+
+ # This `Deferred` does not have a result yet.
+ # The `await` will block normally, so we don't have to do anything.
+ _log_await_stack(
+ stack,
+ self._previous_stack,
+ self._request_number,
+ "blocking await",
+ )
+ self._previous_stack = stack
+
+ return self._original_Deferred___next__(deferred, value)
+
+ return mock.patch.object(Deferred, "__next__", new=Deferred___next__)
+
+ def unblock_awaits(self) -> None:
+ """Unblocks any shared processing that we forced to block.
+
+ Must be called when done, otherwise processing shared between multiple requests,
+ such as database queries started by `@cached`, will become permanently stuck.
+ """
+ to_unblock = self._to_unblock
+ self._to_unblock = {}
+ for deferred, result in to_unblock.items():
+ deferred.callback(result)
+
+
+def _log_for_request(request_number: int, message: str) -> None:
+ """Logs a message for an iteration of `make_request_with_cancellation_test`."""
+ # We want consistent alignment when logging stack traces, so ensure the logging
+ # context has a fixed width name.
+ with LoggingContext(name=f"request-{request_number:<2}"):
+ logger.info(message)
+
+
+def _log_await_stack(
+ stack: List[inspect.FrameInfo],
+ previous_stack: List[inspect.FrameInfo],
+ request_number: int,
+ note: str,
+) -> None:
+ """Logs the stack for an `await` in `make_request_with_cancellation_test`.
+
+ Only logs the part of the stack that has changed since the previous call.
+
+ Example output looks like:
+ ```
+ delay_cancellation:750 (synapse/util/async_helpers.py:750)
+ DatabasePool._runInteraction:768 (synapse/storage/database.py:768)
+ > *blocked on await* at DatabasePool.runWithConnection:891 (synapse/storage/database.py:891)
+ ```
+
+ Args:
+ stack: The stack to log, as returned by `_get_stack()`.
+ previous_stack: The previous stack logged, with callers appearing before
+ callees.
+ request_number: The request number to log against.
+ note: A note to attach to the last stack frame, eg. "blocked on await".
+ """
+ for i, frame_info in enumerate(stack[:-1]):
+ # Skip any frames in common with the previous logging.
+ if i < len(previous_stack) and frame_info == previous_stack[i]:
+ continue
+
+ frame = _format_stack_frame(frame_info)
+ message = f"{' ' * i}{frame}"
+ _log_for_request(request_number, message)
+
+ # Always print the final frame with the `await`.
+ # If the frame with the `await` started another coroutine run, we may have already
+ # printed a deeper stack which includes our final frame. We want to log where all
+ # `await`s happen, so we reprint the frame in this case.
+ i = len(stack) - 1
+ frame_info = stack[i]
+ frame = _format_stack_frame(frame_info)
+ message = f"{' ' * i}> *{note}* at {frame}"
+ _log_for_request(request_number, message)
+
+
+def _format_stack_frame(frame_info: inspect.FrameInfo) -> str:
+ """Returns a string representation of a stack frame.
+
+ Used for debug logging.
+
+ Returns:
+ A string, formatted like
+ "JsonResource._async_render:559 (synapse/http/server.py:559)".
+ """
+ method_name = _get_stack_frame_method_name(frame_info)
+
+ return (
+ f"{method_name}:{frame_info.lineno} ({frame_info.filename}:{frame_info.lineno})"
+ )
+
+
+def _get_stack(skip_frames: int) -> List[inspect.FrameInfo]:
+ """Captures the stack for a request.
+
+ Skips any twisted frames and stops at `JsonResource.wrapped_async_request_handler`.
+
+ Used for debug logging.
+
+ Returns:
+ A list of `inspect.FrameInfo`s, with callers appearing before callees.
+ """
+ stack = []
+
+ skip_frames += 1 # Also skip `get_stack` itself.
+
+ for frame_info in inspect.stack()[skip_frames:]:
+ # Skip any twisted `inlineCallbacks` gunk.
+ if "/twisted/" in frame_info.filename:
+ continue
+
+ # Exclude the reactor frame, upwards.
+ method_name = _get_stack_frame_method_name(frame_info)
+ if method_name == "ThreadedMemoryReactorClock.advance":
+ break
+
+ stack.append(frame_info)
+
+ # Stop at `JsonResource`'s `wrapped_async_request_handler`, which is the entry
+ # point for request handling.
+ if frame_info.function == "wrapped_async_request_handler":
+ break
+
+ return stack[::-1]
+
+
+def _get_stack_frame_method_name(frame_info: inspect.FrameInfo) -> str:
+ """Returns the name of a stack frame's method.
+
+ eg. "JsonResource._async_render".
+ """
+ method_name = frame_info.function
+
+ # Prefix the class name for instance methods.
+ frame_self = frame_info.frame.f_locals.get("self")
+ if frame_self:
+ method = getattr(frame_self, method_name, None)
+ if method:
+ method_name = method.__qualname__
+ else:
+ # We couldn't find the method on `self`.
+ # Make something up. It's useful to know which class "contains" a
+ # function anyway.
+ method_name = f"{type(frame_self).__name__} {method_name}"
+
+ return method_name
+
+
+def _hash_stack(stack: List[inspect.FrameInfo]):
+ """Turns a stack into a hashable value that can be put into a set."""
+ return tuple(_format_stack_frame(frame) for frame in stack)
diff --git a/tests/http/test_fedclient.py b/tests/http/test_matrixfederationclient.py
index 006dbab093..be9eaf34e8 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_matrixfederationclient.py
@@ -617,3 +617,17 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestSendFailed)
self.assertTrue(transport.disconnecting)
+
+ def test_build_auth_headers_rejects_falsey_destinations(self) -> None:
+ with self.assertRaises(ValueError):
+ self.cl.build_auth_headers(None, b"GET", b"https://example.com")
+ with self.assertRaises(ValueError):
+ self.cl.build_auth_headers(b"", b"GET", b"https://example.com")
+ with self.assertRaises(ValueError):
+ self.cl.build_auth_headers(
+ None, b"GET", b"https://example.com", destination_is=b""
+ )
+ with self.assertRaises(ValueError):
+ self.cl.build_auth_headers(
+ b"", b"GET", b"https://example.com", destination_is=b""
+ )
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index b3655d7b44..3cbca0f5a3 100644
--- a/tests/http/test_servlet.py
+++ b/tests/http/test_servlet.py
@@ -18,7 +18,6 @@ from typing import Tuple
from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError
-from synapse.http.server import cancellable
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
@@ -28,9 +27,10 @@ from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.server import HomeServer
from synapse.types import JsonDict
+from synapse.util.cancellation import cancellable
from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
def make_request(content):
@@ -108,9 +108,7 @@ class CancellableRestServlet(RestServlet):
return HTTPStatus.OK, {"result": True}
-class TestRestServletCancellation(
- unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
-):
+class TestRestServletCancellation(unittest.HomeserverTestCase):
"""Tests for `RestServlet` cancellation."""
servlets = [
@@ -120,7 +118,7 @@ class TestRestServletCancellation(
def test_cancellable_disconnect(self) -> None:
"""Test that handlers with the `@cancellable` flag can be cancelled."""
channel = self.make_request("GET", "/sleep", await_result=False)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@@ -130,7 +128,7 @@ class TestRestServletCancellation(
def test_uncancellable_disconnect(self) -> None:
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
channel = self.make_request("POST", "/sleep", await_result=False)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py
index e430941d27..0917e478a5 100644
--- a/tests/logging/test_opentracing.py
+++ b/tests/logging/test_opentracing.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import cast
+
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactorClock
@@ -23,6 +25,8 @@ from synapse.logging.context import (
from synapse.logging.opentracing import (
start_active_span,
start_active_span_follows_from,
+ tag_args,
+ trace_with_opname,
)
from synapse.util import Clock
@@ -36,10 +40,23 @@ try:
except ImportError:
jaeger_client = None # type: ignore
+import logging
+
from tests.unittest import TestCase
+logger = logging.getLogger(__name__)
+
class LogContextScopeManagerTestCase(TestCase):
+ """
+ Test logging contexts and active opentracing spans.
+
+ There's casts throughout this from generic opentracing objects (e.g.
+ opentracing.Span) to the ones specific to Jaeger since they have additional
+ properties that these tests depend on. This is safe since the only supported
+ opentracing backend is Jaeger.
+ """
+
if LogContextScopeManager is None:
skip = "Requires opentracing" # type: ignore[unreachable]
if jaeger_client is None:
@@ -50,7 +67,7 @@ class LogContextScopeManagerTestCase(TestCase):
# global variables that power opentracing. We create our own tracer instance
# and test with it.
- scope_manager = LogContextScopeManager({})
+ scope_manager = LogContextScopeManager()
config = jaeger_client.config.Config(
config={}, service_name="test", scope_manager=scope_manager
)
@@ -69,7 +86,7 @@ class LogContextScopeManagerTestCase(TestCase):
# start_active_span should start and activate a span.
scope = start_active_span("span", tracer=self._tracer)
- span = scope.span
+ span = cast(jaeger_client.Span, scope.span)
self.assertEqual(self._tracer.active_span, span)
self.assertIsNotNone(span.start_time)
@@ -91,6 +108,7 @@ class LogContextScopeManagerTestCase(TestCase):
with LoggingContext("root context"):
with start_active_span("root span", tracer=self._tracer) as root_scope:
self.assertEqual(self._tracer.active_span, root_scope.span)
+ root_context = cast(jaeger_client.SpanContext, root_scope.span.context)
scope1 = start_active_span(
"child1",
@@ -99,9 +117,8 @@ class LogContextScopeManagerTestCase(TestCase):
self.assertEqual(
self._tracer.active_span, scope1.span, "child1 was not activated"
)
- self.assertEqual(
- scope1.span.context.parent_id, root_scope.span.context.span_id
- )
+ context1 = cast(jaeger_client.SpanContext, scope1.span.context)
+ self.assertEqual(context1.parent_id, root_context.span_id)
scope2 = start_active_span_follows_from(
"child2",
@@ -109,17 +126,18 @@ class LogContextScopeManagerTestCase(TestCase):
tracer=self._tracer,
)
self.assertEqual(self._tracer.active_span, scope2.span)
- self.assertEqual(
- scope2.span.context.parent_id, scope1.span.context.span_id
- )
+ context2 = cast(jaeger_client.SpanContext, scope2.span.context)
+ self.assertEqual(context2.parent_id, context1.span_id)
with scope1, scope2:
pass
# the root scope should be restored
self.assertEqual(self._tracer.active_span, root_scope.span)
- self.assertIsNotNone(scope2.span.end_time)
- self.assertIsNotNone(scope1.span.end_time)
+ span2 = cast(jaeger_client.Span, scope2.span)
+ span1 = cast(jaeger_client.Span, scope1.span)
+ self.assertIsNotNone(span2.end_time)
+ self.assertIsNotNone(span1.end_time)
self.assertIsNone(self._tracer.active_span)
@@ -182,3 +200,80 @@ class LogContextScopeManagerTestCase(TestCase):
self._reporter.get_spans(),
[scopes[1].span, scopes[2].span, scopes[0].span],
)
+
+ def test_trace_decorator_sync(self) -> None:
+ """
+ Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
+ with sync functions
+ """
+ with LoggingContext("root context"):
+
+ @trace_with_opname("fixture_sync_func", tracer=self._tracer)
+ @tag_args
+ def fixture_sync_func() -> str:
+ return "foo"
+
+ result = fixture_sync_func()
+ self.assertEqual(result, "foo")
+
+ # the span should have been reported
+ self.assertEqual(
+ [span.operation_name for span in self._reporter.get_spans()],
+ ["fixture_sync_func"],
+ )
+
+ def test_trace_decorator_deferred(self) -> None:
+ """
+ Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
+ with functions that return deferreds
+ """
+ reactor = MemoryReactorClock()
+
+ with LoggingContext("root context"):
+
+ @trace_with_opname("fixture_deferred_func", tracer=self._tracer)
+ @tag_args
+ def fixture_deferred_func() -> "defer.Deferred[str]":
+ d1: defer.Deferred[str] = defer.Deferred()
+ d1.callback("foo")
+ return d1
+
+ result_d1 = fixture_deferred_func()
+
+ # let the tasks complete
+ reactor.pump((2,) * 8)
+
+ self.assertEqual(self.successResultOf(result_d1), "foo")
+
+ # the span should have been reported
+ self.assertEqual(
+ [span.operation_name for span in self._reporter.get_spans()],
+ ["fixture_deferred_func"],
+ )
+
+ def test_trace_decorator_async(self) -> None:
+ """
+ Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
+ with async functions
+ """
+ reactor = MemoryReactorClock()
+
+ with LoggingContext("root context"):
+
+ @trace_with_opname("fixture_async_func", tracer=self._tracer)
+ @tag_args
+ async def fixture_async_func() -> str:
+ return "foo"
+
+ d1 = defer.ensureDeferred(fixture_async_func())
+
+ # let the tasks complete
+ reactor.pump((2,) * 8)
+
+ self.assertEqual(self.successResultOf(d1), "foo")
+
+ # the span should have been reported
+ self.assertEqual(
+ [span.operation_name for span in self._reporter.get_spans()],
+ ["fixture_async_func"],
+ )
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 169e29b590..02cef6f876 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -16,6 +16,7 @@ from unittest.mock import Mock
from twisted.internet import defer
from synapse.api.constants import EduTypes, EventTypes
+from synapse.api.errors import NotFoundError
from synapse.events import EventBase
from synapse.federation.units import Transaction
from synapse.handlers.presence import UserPresenceState
@@ -29,7 +30,6 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.test_utils import simple_async_mock
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import HomeserverTestCase, override_config
-from tests.utils import USE_POSTGRES_FOR_TESTS
class ModuleApiTestCase(HomeserverTestCase):
@@ -532,6 +532,34 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(res["displayname"], "simone")
self.assertIsNone(res["avatar_url"])
+ def test_update_room_membership_remote_join(self):
+ """Test that the module API can join a remote room."""
+ # Necessary to fake a remote join.
+ fake_stream_id = 1
+ mocked_remote_join = simple_async_mock(
+ return_value=("fake-event-id", fake_stream_id)
+ )
+ self.hs.get_room_member_handler()._remote_join = mocked_remote_join
+ fake_remote_host = f"{self.module_api.server_name}-remote"
+
+ # Given that the join is to be faked, we expect the relevant join event not to
+ # be persisted and the module API method to raise that.
+ self.get_failure(
+ defer.ensureDeferred(
+ self.module_api.update_room_membership(
+ sender=f"@user:{self.module_api.server_name}",
+ target=f"@user:{self.module_api.server_name}",
+ room_id=f"!nonexistent:{fake_remote_host}",
+ new_membership="join",
+ remote_room_hosts=[fake_remote_host],
+ )
+ ),
+ NotFoundError,
+ )
+
+ # Check that a remote join was attempted.
+ self.assertEqual(mocked_remote_join.call_count, 1)
+
def test_get_room_state(self):
"""Tests that a module can retrieve the state of a room through the module API."""
user_id = self.register_user("peter", "hackme")
@@ -635,15 +663,80 @@ class ModuleApiTestCase(HomeserverTestCase):
[{"set_tweak": "sound", "value": "default"}]
)
+ def test_lookup_room_alias(self) -> None:
+ """Test that modules can resolve a room alias to a room ID."""
+ password = "password"
+ user_id = self.register_user("user", password)
+ access_token = self.login(user_id, password)
+ room_alias = "my-alias"
+ reference_room_id = self.helper.create_room_as(
+ tok=access_token, extra_content={"room_alias_name": room_alias}
+ )
+ self.assertIsNotNone(reference_room_id)
+
+ (room_id, _) = self.get_success(
+ self.module_api.lookup_room_alias(
+ f"#{room_alias}:{self.module_api.server_name}"
+ )
+ )
+
+ self.assertEqual(room_id, reference_room_id)
+
+ def test_create_room(self) -> None:
+ """Test that modules can create a room."""
+ # First test user validation (i.e. user is local).
+ self.get_failure(
+ self.module_api.create_room(
+ user_id=f"@user:{self.module_api.server_name}abc",
+ config={},
+ ratelimit=False,
+ ),
+ RuntimeError,
+ )
+
+ # Now do the happy path.
+ user_id = self.register_user("user", "password")
+ access_token = self.login(user_id, "password")
+
+ room_id, room_alias = self.get_success(
+ self.module_api.create_room(
+ user_id=user_id, config={"room_alias_name": "foo-bar"}, ratelimit=False
+ )
+ )
+
+ # Check room creator.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v3/rooms/{room_id}/state/m.room.create",
+ access_token=access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.json_body["creator"], user_id)
+
+ # Check room alias.
+ self.assertEquals(room_alias, f"#foo-bar:{self.module_api.server_name}")
+
+ # Let's try a room with no alias.
+ room_id, room_alias = self.get_success(
+ self.module_api.create_room(user_id=user_id, config={}, ratelimit=False)
+ )
+
+ # Check room creator.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v3/rooms/{room_id}/state/m.room.create",
+ access_token=access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.json_body["creator"], user_id)
+
+ # Check room alias.
+ self.assertIsNone(room_alias)
+
class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
"""For testing ModuleApi functionality in a multi-worker setup"""
- # Testing stream ID replication from the main to worker processes requires postgres
- # (due to needing `MultiWriterIdGenerator`).
- if not USE_POSTGRES_FOR_TESTS:
- skip = "Requires Postgres"
-
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -653,7 +746,6 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
def default_config(self):
conf = super().default_config()
- conf["redis"] = {"enabled": "true"}
conf["stream_writers"] = {"presence": ["presence_writer"]}
conf["instance_map"] = {
"presence_writer": {"host": "testserv", "port": 1001},
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index ba158f5d93..d9c68cdd2d 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -577,7 +577,7 @@ class HTTPPusherTests(HomeserverTestCase):
# Carry out our option-value specific test
#
# This push should still only contain an unread count of 1 (for 1 unread room)
- self._check_push_attempt(6, 1)
+ self._check_push_attempt(7, 1)
@override_config({"push": {"group_unread_count_by_room": False}})
def test_push_unread_count_message_count(self) -> None:
@@ -591,7 +591,7 @@ class HTTPPusherTests(HomeserverTestCase):
#
# We're counting every unread message, so there should now be 3 since the
# last read receipt
- self._check_push_attempt(6, 3)
+ self._check_push_attempt(7, 3)
def _test_push_unread_count(self) -> None:
"""
@@ -641,18 +641,18 @@ class HTTPPusherTests(HomeserverTestCase):
response = self.helper.send(
room_id, body="Hello there!", tok=other_access_token
)
- # To get an unread count, the user who is getting notified has to have a read
- # position in the room. We'll set the read position to this event in a moment
+
first_message_event_id = response["event_id"]
expected_push_attempts = 1
- self._check_push_attempt(expected_push_attempts, 0)
+ self._check_push_attempt(expected_push_attempts, 1)
self._send_read_request(access_token, first_message_event_id, room_id)
- # Unread count has not changed. Therefore, ensure that read request does not
- # trigger a push notification.
- self.assertEqual(len(self.push_attempts), 1)
+ # Unread count has changed. Therefore, ensure that read request triggers
+ # a push notification.
+ expected_push_attempts += 1
+ self.assertEqual(len(self.push_attempts), expected_push_attempts)
# Send another message
response2 = self.helper.send(
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 9b623d0033..718f489577 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -16,13 +16,23 @@ from typing import Dict, Optional, Set, Tuple, Union
import frozendict
+from twisted.test.proto_helpers import MemoryReactor
+
+import synapse.rest.admin
+from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
+from synapse.appservice import ApplicationService
from synapse.events import FrozenEvent
from synapse.push import push_rule_evaluator
from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
+from synapse.rest.client import login, register, room
+from synapse.server import HomeServer
+from synapse.storage.databases.main.appservice import _make_exclusive_regex
from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
+from tests.test_utils.event_injection import create_event, inject_member_event
class PushRuleEvaluatorTestCase(unittest.TestCase):
@@ -354,3 +364,78 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"event_type": "*.reaction",
}
self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+
+
+class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
+ """Tests for the bulk push rule evaluator"""
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ register.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
+ # Define an application service so that we can register appservice users
+ self._service_token = "some_token"
+ self._service = ApplicationService(
+ self._service_token,
+ "as1",
+ "@as.sender:test",
+ namespaces={
+ "users": [
+ {"regex": "@_as_.*:test", "exclusive": True},
+ {"regex": "@as.sender:test", "exclusive": True},
+ ]
+ },
+ msc3202_transaction_extensions=True,
+ )
+ self.hs.get_datastores().main.services_cache = [self._service]
+ self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
+ [self._service]
+ )
+
+ self._as_user, _ = self.register_appservice_user(
+ "_as_user", self._service_token
+ )
+
+ self.evaluator = self.hs.get_bulk_push_rule_evaluator()
+
+ def test_ignore_appservice_users(self) -> None:
+ "Test that we don't generate push for appservice users"
+
+ user_id = self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ room_id = self.helper.create_room_as(user_id, tok=token)
+ self.get_success(
+ inject_member_event(self.hs, room_id, self._as_user, Membership.JOIN)
+ )
+
+ event, context = self.get_success(
+ create_event(
+ self.hs,
+ type=EventTypes.Message,
+ room_id=room_id,
+ sender=user_id,
+ content={"body": "test", "msgtype": "m.text"},
+ )
+ )
+
+ # Assert the returned push rules do not contain the app service user
+ rules = self.get_success(self.evaluator._get_rules_for_event(event))
+ self.assertTrue(self._as_user not in rules)
+
+ # Assert that no push actions have been added to the staging table (the
+ # sender should not be pushed for the event)
+ users_with_push_actions = self.get_success(
+ self.hs.get_datastores().main.db_pool.simple_select_onecol(
+ table="event_push_actions_staging",
+ keyvalues={"event_id": event.event_id},
+ retcol="user_id",
+ desc="test_ignore_appservice_users",
+ )
+ )
+
+ self.assertEqual(len(users_with_push_actions), 0)
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 970d5e533b..ce53f808db 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -24,11 +24,11 @@ from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
-from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import (
- ReplicationStreamProtocolFactory,
+from synapse.replication.tcp.protocol import (
+ ClientReplicationStreamProtocol,
ServerReplicationStreamProtocol,
)
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
from tests import unittest
@@ -220,15 +220,34 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests running multiple workers.
+ Enables Redis, providing a fake Redis server.
+
Automatically handle HTTP replication requests from workers to master,
unlike `BaseStreamTestCase`.
"""
+ if not hiredis:
+ skip = "Requires hiredis"
+
+ if not USE_POSTGRES_FOR_TESTS:
+ # Redis replication only takes place on Postgres
+ skip = "Requires Postgres"
+
+ def default_config(self) -> Dict[str, Any]:
+ """
+ Overrides the default config to enable Redis.
+ Even if the test only uses make_worker_hs, the main process needs Redis
+ enabled otherwise it won't create a Fake Redis server to listen on the
+ Redis port and accept fake TCP connections.
+ """
+ base = super().default_config()
+ base["redis"] = {"enabled": True}
+ return base
+
def setUp(self):
super().setUp()
# build a replication server
- self.server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = self.hs.get_replication_streamer()
# Fake in memory Redis server that servers can connect to.
@@ -247,15 +266,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# handling inbound HTTP requests to that instance.
self._hs_to_site = {self.hs: self.site}
- if self.hs.config.redis.redis_enabled:
- # Handle attempts to connect to fake redis server.
- self.reactor.add_tcp_client_callback(
- "localhost",
- 6379,
- self.connect_any_redis_attempts,
- )
+ # Handle attempts to connect to fake redis server.
+ self.reactor.add_tcp_client_callback(
+ "localhost",
+ 6379,
+ self.connect_any_redis_attempts,
+ )
- self.hs.get_replication_command_handler().start_replication(self.hs)
+ self.hs.get_replication_command_handler().start_replication(self.hs)
# When we see a connection attempt to the master replication listener we
# automatically set up the connection. This is so that tests don't
@@ -339,27 +357,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
store = worker_hs.get_datastores().main
store.db_pool._db_pool = self.database_pool._db_pool
- # Set up TCP replication between master and the new worker if we don't
- # have Redis support enabled.
- if not worker_hs.config.redis.redis_enabled:
- repl_handler = ReplicationCommandHandler(worker_hs)
- client = ClientReplicationStreamProtocol(
- worker_hs,
- "client",
- "test",
- self.clock,
- repl_handler,
- )
- server = self.server_factory.buildProtocol(
- IPv4Address("TCP", "127.0.0.1", 0)
- )
-
- client_transport = FakeTransport(server, self.reactor)
- client.makeConnection(client_transport)
-
- server_transport = FakeTransport(client, self.reactor)
- server.makeConnection(server_transport)
-
# Set up a resource for the worker
resource = ReplicationRestResource(worker_hs)
@@ -378,8 +375,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
reactor=self.reactor,
)
- if worker_hs.config.redis.redis_enabled:
- worker_hs.get_replication_command_handler().start_replication(worker_hs)
+ worker_hs.get_replication_command_handler().start_replication(worker_hs)
return worker_hs
@@ -582,27 +578,3 @@ class FakeRedisPubSubProtocol(Protocol):
def connectionLost(self, reason):
self._server.remove_subscriber(self)
-
-
-class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase):
- """
- A test case that enables Redis, providing a fake Redis server.
- """
-
- if not hiredis:
- skip = "Requires hiredis"
-
- if not USE_POSTGRES_FOR_TESTS:
- # Redis replication only takes place on Postgres
- skip = "Requires Postgres"
-
- def default_config(self) -> Dict[str, Any]:
- """
- Overrides the default config to enable Redis.
- Even if the test only uses make_worker_hs, the main process needs Redis
- enabled otherwise it won't create a Fake Redis server to listen on the
- Redis port and accept fake TCP connections.
- """
- base = super().default_config()
- base["redis"] = {"enabled": True}
- return base
diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py
index a5ab093a27..936ab4504a 100644
--- a/tests/replication/http/test__base.py
+++ b/tests/replication/http/test__base.py
@@ -18,14 +18,15 @@ from typing import Tuple
from twisted.web.server import Request
from synapse.api.errors import Codes
-from synapse.http.server import JsonResource, cancellable
+from synapse.http.server import JsonResource
from synapse.replication.http import REPLICATION_PREFIX
from synapse.replication.http._base import ReplicationEndpoint
from synapse.server import HomeServer
from synapse.types import JsonDict
+from synapse.util.cancellation import cancellable
from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
class CancellableReplicationEndpoint(ReplicationEndpoint):
@@ -69,9 +70,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
return HTTPStatus.OK, {"result": True}
-class ReplicationEndpointCancellationTestCase(
- unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
-):
+class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
"""Tests for `ReplicationEndpoint` cancellation."""
def create_test_resource(self):
@@ -87,7 +86,7 @@ class ReplicationEndpointCancellationTestCase(
"""Test that handlers with the `@cancellable` flag can be cancelled."""
path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/"
channel = self.make_request("POST", path, await_result=False)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@@ -98,7 +97,7 @@ class ReplicationEndpointCancellationTestCase(
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/"
channel = self.make_request("POST", path, await_result=False)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py
deleted file mode 100644
index 1524087c43..0000000000
--- a/tests/replication/slave/storage/test_account_data.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-
-from ._base import BaseSlavedStoreTestCase
-
-USER_ID = "@feeling:blue"
-TYPE = "my.type"
-
-
-class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
-
- STORE_TYPE = SlavedAccountDataStore
-
- def test_user_account_data(self):
- self.get_success(
- self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1})
- )
- self.replicate()
- self.check(
- "get_global_account_data_by_type_for_user", [USER_ID, TYPE], {"a": 1}
- )
-
- self.get_success(
- self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2})
- )
- self.replicate()
- self.check(
- "get_global_account_data_by_type_for_user", [USER_ID, TYPE], {"a": 2}
- )
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 6d3d4afe52..531a0db2d0 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -15,7 +15,9 @@ import logging
from typing import Iterable, Optional
from canonicaljson import encode_canonical_json
+from parameterized import parameterized
+from synapse.api.constants import ReceiptTypes
from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
from synapse.handlers.room import RoomEventSource
@@ -156,17 +158,26 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
],
)
- def test_push_actions_for_user(self):
+ @parameterized.expand([(True,), (False,)])
+ def test_push_actions_for_user(self, send_receipt: bool):
self.persist(type="m.room.create", key="", creator=USER_ID)
- self.persist(type="m.room.join", key=USER_ID, membership="join")
+ self.persist(type="m.room.member", key=USER_ID, membership="join")
self.persist(
- type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join"
+ type="m.room.member", sender=USER_ID, key=USER_ID_2, membership="join"
)
event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello")
self.replicate()
+
+ if send_receipt:
+ self.get_success(
+ self.master_store.insert_receipt(
+ ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], {}
+ )
+ )
+
self.check(
"get_unread_event_push_actions_by_room_for_user",
- [ROOM_ID, USER_ID_2, event1.event_id],
+ [ROOM_ID, USER_ID_2],
NotifCounts(highlight_count=0, unread_count=0, notify_count=0),
)
@@ -179,7 +190,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.replicate()
self.check(
"get_unread_event_push_actions_by_room_for_user",
- [ROOM_ID, USER_ID_2, event1.event_id],
+ [ROOM_ID, USER_ID_2],
NotifCounts(highlight_count=0, unread_count=0, notify_count=1),
)
@@ -194,7 +205,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.replicate()
self.check(
"get_unread_event_push_actions_by_room_for_user",
- [ROOM_ID, USER_ID_2, event1.event_id],
+ [ROOM_ID, USER_ID_2],
NotifCounts(highlight_count=1, unread_count=0, notify_count=2),
)
diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index e6a19eafd5..1e299d2d67 100644
--- a/tests/replication/tcp/test_handler.py
+++ b/tests/replication/tcp/test_handler.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from tests.replication._base import RedisMultiWorkerStreamTestCase
+from tests.replication._base import BaseMultiWorkerStreamTestCase
-class ChannelsTestCase(RedisMultiWorkerStreamTestCase):
+class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
def test_subscribed_to_enough_redis_channels(self) -> None:
# The default main process is subscribed to the USER_IP channel.
self.assertCountEqual(
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index a7ca68069e..541d390286 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -20,7 +20,6 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
-from tests.utils import USE_POSTGRES_FOR_TESTS
logger = logging.getLogger(__name__)
@@ -28,11 +27,6 @@ logger = logging.getLogger(__name__)
class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
"""Checks event persisting sharding works"""
- # Event persister sharding requires postgres (due to needing
- # `MultiWriterIdGenerator`).
- if not USE_POSTGRES_FOR_TESTS:
- skip = "Requires Postgres"
-
servlets = [
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -50,7 +44,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
def default_config(self):
conf = super().default_config()
- conf["redis"] = {"enabled": "true"}
conf["stream_writers"] = {"events": ["worker1", "worker2"]}
conf["instance_map"] = {
"worker1": {"host": "testserv", "port": 1001},
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 82ac5991e6..a8f6436836 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -13,7 +13,6 @@
# limitations under the License.
import urllib.parse
-from http import HTTPStatus
from parameterized import parameterized
@@ -42,7 +41,7 @@ class VersionTestCase(unittest.HomeserverTestCase):
def test_version_string(self) -> None:
channel = self.make_request("GET", self.url, shorthand=False)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
{"server_version", "python_version"}, set(channel.json_body.keys())
)
@@ -79,10 +78,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Should be quarantined
self.assertEqual(
- HTTPStatus.NOT_FOUND,
+ 404,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.NOT_FOUND on accessing quarantined media: %s"
+ "Expected to receive a 404 on accessing quarantined media: %s"
% server_and_media_id
),
)
@@ -107,7 +106,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Expect a forbidden error
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg="Expected forbidden on quarantining media as a non-admin",
)
@@ -139,7 +138,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
)
# Should be successful
- self.assertEqual(HTTPStatus.OK, channel.code)
+ self.assertEqual(200, channel.code)
# Quarantine the media
url = "/_synapse/admin/v1/media/quarantine/%s/%s" % (
@@ -152,7 +151,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok,
)
self.pump(1.0)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Attempt to access the media
self._ensure_quarantined(admin_user_tok, server_name_and_media_id)
@@ -209,7 +208,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok,
)
self.pump(1.0)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items"
)
@@ -251,7 +250,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok,
)
self.pump(1.0)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items"
)
@@ -285,7 +284,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),)
channel = self.make_request("POST", url, access_token=admin_user_tok)
self.pump(1.0)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Quarantine all media by this user
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
@@ -297,7 +296,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok,
)
self.pump(1.0)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body, {"num_quarantined": 1}, "Expected 1 quarantined item"
)
@@ -318,10 +317,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Shouldn't be quarantined
self.assertEqual(
- HTTPStatus.OK,
+ 200,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.OK on accessing not-quarantined media: %s"
+ "Expected to receive a 200 on accessing not-quarantined media: %s"
% server_and_media_id_2
),
)
@@ -350,7 +349,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
def test_purge_history(self) -> None:
"""
Simple test of purge history API.
- Test only that is is possible to call, get status HTTPStatus.OK and purge_id.
+ Test only that is is possible to call, get status 200 and purge_id.
"""
channel = self.make_request(
@@ -360,7 +359,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("purge_id", channel.json_body)
purge_id = channel.json_body["purge_id"]
@@ -371,5 +370,5 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("complete", channel.json_body["status"])
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index 6cf56b1e35..d507a3af8d 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from typing import Collection
from parameterized import parameterized
@@ -51,7 +50,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
)
def test_requester_is_no_admin(self, method: str, url: str) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
self.register_user("user", "pass", admin=False)
@@ -64,7 +63,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
access_token=other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -81,7 +80,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
# job_name invalid
@@ -92,7 +91,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
def _register_bg_update(self) -> None:
@@ -125,7 +124,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Background updates should be enabled, but none should be running.
self.assertDictEqual(
@@ -147,7 +146,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Background updates should be enabled, and one should be running.
self.assertDictEqual(
@@ -181,7 +180,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/enabled",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertDictEqual(channel.json_body, {"enabled": True})
# Disable the BG updates
@@ -191,7 +190,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
content={"enabled": False},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertDictEqual(channel.json_body, {"enabled": False})
# Advance a bit and get the current status, note this will finish the in
@@ -204,7 +203,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertDictEqual(
channel.json_body,
{
@@ -231,7 +230,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# There should be no change from the previous /status response.
self.assertDictEqual(
@@ -259,7 +258,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
content={"enabled": True},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertDictEqual(channel.json_body, {"enabled": True})
@@ -270,7 +269,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Background updates should be enabled and making progress.
self.assertDictEqual(
@@ -325,7 +324,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# test that each background update is waiting now
for update in updates:
@@ -365,4 +364,4 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index f7080bda87..d52aee8f92 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import urllib.parse
-from http import HTTPStatus
from parameterized import parameterized
@@ -58,7 +57,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(method, self.url, b"{}")
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -76,7 +75,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg=channel.json_body,
)
@@ -85,7 +84,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
@parameterized.expand(["GET", "PUT", "DELETE"])
def test_user_does_not_exist(self, method: str) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = (
"/_synapse/admin/v2/users/@unknown_person:test/devices/%s"
@@ -98,13 +97,13 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@parameterized.expand(["GET", "PUT", "DELETE"])
def test_user_is_not_local(self, method: str) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = (
"/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices/%s"
@@ -117,12 +116,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
def test_unknown_device(self) -> None:
"""
- Tests that a lookup for a device that does not exist returns either HTTPStatus.NOT_FOUND or HTTPStatus.OK.
+ Tests that a lookup for a device that does not exist returns either 404 or 200.
"""
url = "/_synapse/admin/v2/users/%s/devices/unknown_device" % urllib.parse.quote(
self.other_user
@@ -134,7 +133,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
channel = self.make_request(
@@ -143,7 +142,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
channel = self.make_request(
"DELETE",
@@ -151,8 +150,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- # Delete unknown device returns status HTTPStatus.OK
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ # Delete unknown device returns status 200
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def test_update_device_too_long_display_name(self) -> None:
"""
@@ -179,7 +178,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
content=update,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
# Ensure the display name was not updated.
@@ -189,12 +188,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"])
def test_update_no_display_name(self) -> None:
"""
- Tests that a update for a device without JSON returns a HTTPStatus.OK
+ Tests that a update for a device without JSON returns a 200
"""
# Set iniital display name.
update = {"display_name": "new display"}
@@ -210,7 +209,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Ensure the display name was not updated.
channel = self.make_request(
@@ -219,7 +218,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"])
def test_update_display_name(self) -> None:
@@ -234,7 +233,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
content={"display_name": "new displayname"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check new display_name
channel = self.make_request(
@@ -243,7 +242,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new displayname", channel.json_body["display_name"])
def test_get_device(self) -> None:
@@ -256,7 +255,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
# Check that all fields are available
self.assertIn("user_id", channel.json_body)
@@ -281,7 +280,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Ensure that the number of devices is decreased
res = self.get_success(self.handler.get_devices_by_user(self.other_user))
@@ -312,7 +311,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -331,7 +330,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg=channel.json_body,
)
@@ -339,7 +338,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
def test_user_does_not_exist(self) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v2/users/@unknown_person:test/devices"
channel = self.make_request(
@@ -348,12 +347,12 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices"
@@ -363,7 +362,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
def test_user_has_no_devices(self) -> None:
@@ -379,7 +378,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["devices"]))
@@ -399,7 +398,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_devices, channel.json_body["total"])
self.assertEqual(number_devices, len(channel.json_body["devices"]))
self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"])
@@ -438,7 +437,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}")
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -457,7 +456,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg=channel.json_body,
)
@@ -465,7 +464,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
def test_user_does_not_exist(self) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices"
channel = self.make_request(
@@ -474,12 +473,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices"
@@ -489,12 +488,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
def test_unknown_devices(self) -> None:
"""
- Tests that a remove of a device that does not exist returns HTTPStatus.OK.
+ Tests that a remove of a device that does not exist returns 200.
"""
channel = self.make_request(
"POST",
@@ -503,8 +502,8 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
content={"devices": ["unknown_device1", "unknown_device2"]},
)
- # Delete unknown devices returns status HTTPStatus.OK
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ # Delete unknown devices returns status 200
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def test_delete_devices(self) -> None:
"""
@@ -533,7 +532,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
content={"devices": device_ids},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
res = self.get_success(self.handler.get_devices_by_user(self.other_user))
self.assertEqual(0, len(res))
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 4f89f8b534..8a4e5c3f77 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from typing import List
from twisted.test.proto_helpers import MemoryReactor
@@ -81,16 +80,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -99,11 +94,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_default_success(self) -> None:
@@ -117,7 +108,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
self.assertNotIn("next_token", channel.json_body)
@@ -134,7 +125,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 5)
self.assertEqual(channel.json_body["next_token"], 5)
@@ -151,7 +142,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 15)
self.assertNotIn("next_token", channel.json_body)
@@ -168,7 +159,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(channel.json_body["next_token"], 15)
self.assertEqual(len(channel.json_body["event_reports"]), 10)
@@ -185,7 +176,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["event_reports"]), 10)
self.assertNotIn("next_token", channel.json_body)
@@ -205,7 +196,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["event_reports"]), 10)
self.assertNotIn("next_token", channel.json_body)
@@ -225,7 +216,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 5)
self.assertEqual(len(channel.json_body["event_reports"]), 5)
self.assertNotIn("next_token", channel.json_body)
@@ -247,7 +238,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
report = 1
@@ -265,7 +256,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
report = 1
@@ -278,7 +269,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
def test_invalid_search_order(self) -> None:
"""
- Testing that a invalid search order returns a HTTPStatus.BAD_REQUEST
+ Testing that a invalid search order returns a 400
"""
channel = self.make_request(
@@ -287,17 +278,13 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual("Unknown direction: bar", channel.json_body["error"])
def test_limit_is_negative(self) -> None:
"""
- Testing that a negative limit parameter returns a HTTPStatus.BAD_REQUEST
+ Testing that a negative limit parameter returns a 400
"""
channel = self.make_request(
@@ -306,16 +293,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_from_is_negative(self) -> None:
"""
- Testing that a negative from parameter returns a HTTPStatus.BAD_REQUEST
+ Testing that a negative from parameter returns a 400
"""
channel = self.make_request(
@@ -324,11 +307,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_next_token(self) -> None:
@@ -344,7 +323,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
self.assertNotIn("next_token", channel.json_body)
@@ -357,7 +336,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
self.assertNotIn("next_token", channel.json_body)
@@ -370,7 +349,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 19)
self.assertEqual(channel.json_body["next_token"], 19)
@@ -384,7 +363,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -400,7 +379,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
{"score": -100, "reason": "this makes me sad"},
access_token=user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def _create_event_and_report_without_parameters(
self, room_id: str, user_tok: str
@@ -415,7 +394,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
{},
access_token=user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def _check_fields(self, content: List[JsonDict]) -> None:
"""Checks that all attributes are present in an event report"""
@@ -431,6 +410,33 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertIn("score", c)
self.assertIn("reason", c)
+ def test_count_correct_despite_table_deletions(self) -> None:
+ """
+ Tests that the count matches the number of rows, even if rows in joined tables
+ are missing.
+ """
+
+ # Delete rows from room_stats_state for one of our rooms.
+ self.get_success(
+ self.hs.get_datastores().main.db_pool.simple_delete(
+ "room_stats_state", {"room_id": self.room_id1}, desc="_"
+ )
+ )
+
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ # The 'total' field is 10 because only 10 reports will actually
+ # be retrievable since we deleted the rows in the room_stats_state
+ # table.
+ self.assertEqual(channel.json_body["total"], 10)
+ # This is consistent with the number of rows actually returned.
+ self.assertEqual(len(channel.json_body["event_reports"]), 10)
+
class EventReportDetailTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -466,16 +472,12 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -484,11 +486,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_default_success(self) -> None:
@@ -502,12 +500,12 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self._check_fields(channel.json_body)
def test_invalid_report_id(self) -> None:
"""
- Testing that an invalid `report_id` returns a HTTPStatus.BAD_REQUEST.
+ Testing that an invalid `report_id` returns a 400.
"""
# `report_id` is negative
@@ -517,11 +515,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"The report_id parameter must be a string representing a positive integer.",
@@ -535,11 +529,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"The report_id parameter must be a string representing a positive integer.",
@@ -553,11 +543,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"The report_id parameter must be a string representing a positive integer.",
@@ -566,7 +552,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
def test_report_id_not_found(self) -> None:
"""
- Testing that a not existing `report_id` returns a HTTPStatus.NOT_FOUND.
+ Testing that a not existing `report_id` returns a 404.
"""
channel = self.make_request(
@@ -575,11 +561,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.NOT_FOUND,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
self.assertEqual("Event report not found", channel.json_body["error"])
@@ -594,7 +576,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
{"score": -100, "reason": "this makes me sad"},
access_token=user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def _check_fields(self, content: JsonDict) -> None:
"""Checks that all attributes are present in a event report"""
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index 929bbdc37d..4c7864c629 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from typing import List, Optional
from parameterized import parameterized
@@ -64,7 +63,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -77,7 +76,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -87,7 +86,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# unkown order_by
@@ -97,7 +96,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -107,7 +106,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid destination
@@ -117,7 +116,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
# invalid destination
@@ -127,7 +126,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_limit(self) -> None:
@@ -142,7 +141,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), 5)
self.assertEqual(channel.json_body["next_token"], "5")
@@ -160,7 +159,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), 15)
self.assertNotIn("next_token", channel.json_body)
@@ -178,7 +177,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(channel.json_body["next_token"], "15")
self.assertEqual(len(channel.json_body["destinations"]), 10)
@@ -198,7 +197,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), number_destinations)
self.assertNotIn("next_token", channel.json_body)
@@ -211,7 +210,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), number_destinations)
self.assertNotIn("next_token", channel.json_body)
@@ -224,7 +223,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), 19)
self.assertEqual(channel.json_body["next_token"], "19")
@@ -238,7 +237,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -255,7 +254,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_destinations, len(channel.json_body["destinations"]))
self.assertEqual(number_destinations, channel.json_body["total"])
@@ -290,7 +289,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_destination_list))
returned_order = [
@@ -376,7 +375,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that destinations were returned
self.assertTrue("destinations" in channel.json_body)
@@ -418,7 +417,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("sub0.example.com", channel.json_body["destination"])
# Check that all fields are available
@@ -435,7 +434,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("sub0.example.com", channel.json_body["destination"])
self.assertEqual(0, channel.json_body["retry_last_ts"])
self.assertEqual(0, channel.json_body["retry_interval"])
@@ -452,7 +451,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
retry_timings = self.get_success(
self.store.get_destination_retry_timings("sub0.example.com")
@@ -469,7 +468,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"The retry timing does not need to be reset for this destination.",
channel.json_body["error"],
@@ -561,7 +560,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -574,7 +573,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -584,7 +583,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -594,7 +593,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid destination
@@ -604,7 +603,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_limit(self) -> None:
@@ -619,7 +618,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), 3)
self.assertEqual(channel.json_body["next_token"], "3")
@@ -637,7 +636,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), 5)
self.assertNotIn("next_token", channel.json_body)
@@ -655,7 +654,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(channel.json_body["next_token"], "8")
self.assertEqual(len(channel.json_body["rooms"]), 5)
@@ -673,7 +672,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel_asc.code, msg=channel_asc.json_body)
+ self.assertEqual(200, channel_asc.code, msg=channel_asc.json_body)
self.assertEqual(channel_asc.json_body["total"], number_rooms)
self.assertEqual(number_rooms, len(channel_asc.json_body["rooms"]))
self._check_fields(channel_asc.json_body["rooms"])
@@ -685,7 +684,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel_desc.code, msg=channel_desc.json_body)
+ self.assertEqual(200, channel_desc.code, msg=channel_desc.json_body)
self.assertEqual(channel_desc.json_body["total"], number_rooms)
self.assertEqual(number_rooms, len(channel_desc.json_body["rooms"]))
self._check_fields(channel_desc.json_body["rooms"])
@@ -711,7 +710,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), number_rooms)
self.assertNotIn("next_token", channel.json_body)
@@ -724,7 +723,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), number_rooms)
self.assertNotIn("next_token", channel.json_body)
@@ -737,7 +736,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), 4)
self.assertEqual(channel.json_body["next_token"], "4")
@@ -751,7 +750,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -767,7 +766,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(number_rooms, len(channel.json_body["rooms"]))
self._check_fields(channel.json_body["rooms"])
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index e909e444ac..aadb31ca83 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
-from http import HTTPStatus
from parameterized import parameterized
@@ -60,7 +59,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
channel = self.make_request("DELETE", url, b"{}")
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -81,16 +80,12 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_media_does_not_exist(self) -> None:
"""
- Tests that a lookup for a media that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a media that does not exist returns a 404
"""
url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
@@ -100,12 +95,12 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_media_is_not_local(self) -> None:
"""
- Tests that a lookup for a media that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a media that is not a local returns a 400
"""
url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345")
@@ -115,7 +110,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only delete local media", channel.json_body["error"])
def test_delete_media(self) -> None:
@@ -131,7 +126,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
upload_resource,
SMALL_PNG,
tok=self.admin_user_tok,
- expect_code=HTTPStatus.OK,
+ expect_code=200,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -151,11 +146,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
# Should be successful
self.assertEqual(
- HTTPStatus.OK,
+ 200,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.OK on accessing media: %s"
- % server_and_media_id
+ "Expected to receive a 200 on accessing media: %s" % server_and_media_id
),
)
@@ -172,7 +166,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
media_id,
@@ -189,10 +183,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
self.assertEqual(
- HTTPStatus.NOT_FOUND,
+ 404,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s"
+ "Expected to receive a 404 on accessing deleted media: %s"
% server_and_media_id
),
)
@@ -231,11 +225,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
@@ -251,16 +241,12 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_media_is_not_local(self) -> None:
"""
- Tests that a lookup for media that is not local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for media that is not local returns a 400
"""
url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain"
@@ -270,7 +256,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only delete local media", channel.json_body["error"])
def test_missing_parameter(self) -> None:
@@ -283,11 +269,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Missing integer query parameter 'before_ts'", channel.json_body["error"]
@@ -303,11 +285,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts must be a positive integer.",
@@ -320,11 +298,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts you provided is from the year 1970. "
@@ -338,11 +312,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter size_gt must be a string representing a positive integer.",
@@ -355,11 +325,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Boolean query parameter 'keep_profiles' must be one of ['true', 'false']",
@@ -388,7 +354,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
media_id,
@@ -413,7 +379,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id)
@@ -425,7 +391,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
server_and_media_id.split("/")[1],
@@ -449,7 +415,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&size_gt=67",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id)
@@ -460,7 +426,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&size_gt=66",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
server_and_media_id.split("/")[1],
@@ -485,7 +451,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
content={"avatar_url": "mxc://%s" % (server_and_media_id,)},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
now_ms = self.clock.time_msec()
channel = self.make_request(
@@ -493,7 +459,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id)
@@ -504,7 +470,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
server_and_media_id.split("/")[1],
@@ -530,7 +496,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
content={"url": "mxc://%s" % (server_and_media_id,)},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
now_ms = self.clock.time_msec()
channel = self.make_request(
@@ -538,7 +504,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id)
@@ -549,7 +515,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
server_and_media_id.split("/")[1],
@@ -569,7 +535,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
upload_resource,
SMALL_PNG,
tok=self.admin_user_tok,
- expect_code=HTTPStatus.OK,
+ expect_code=200,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -602,10 +568,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
if expect_success:
self.assertEqual(
- HTTPStatus.OK,
+ 200,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.OK on accessing media: %s"
+ "Expected to receive a 200 on accessing media: %s"
% server_and_media_id
),
)
@@ -613,10 +579,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertTrue(os.path.exists(local_path))
else:
self.assertEqual(
- HTTPStatus.NOT_FOUND,
+ 404,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s"
+ "Expected to receive a 404 on accessing deleted media: %s"
% (server_and_media_id)
),
)
@@ -648,7 +614,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
upload_resource,
SMALL_PNG,
tok=self.admin_user_tok,
- expect_code=HTTPStatus.OK,
+ expect_code=200,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -668,11 +634,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
b"{}",
)
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["quarantine", "unquarantine"])
@@ -689,11 +651,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_quarantine_media(self) -> None:
@@ -712,7 +670,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
@@ -726,7 +684,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
@@ -753,7 +711,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
# verify that is not in quarantine
@@ -785,7 +743,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
upload_resource,
SMALL_PNG,
tok=self.admin_user_tok,
- expect_code=HTTPStatus.OK,
+ expect_code=200,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -801,11 +759,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url % (action, self.media_id), b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["protect", "unprotect"])
@@ -822,11 +776,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_protect_media(self) -> None:
@@ -845,7 +795,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
@@ -859,7 +809,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
@@ -895,7 +845,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}")
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -914,11 +864,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -931,11 +877,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts must be a positive integer.",
@@ -948,11 +890,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts you provided is from the year 1970. "
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
index 8354250ec2..8f8abc21c7 100644
--- a/tests/rest/admin/test_registration_tokens.py
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -13,7 +13,6 @@
# limitations under the License.
import random
import string
-from http import HTTPStatus
from typing import Optional
from twisted.test.proto_helpers import MemoryReactor
@@ -74,11 +73,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
def test_create_no_auth(self) -> None:
"""Try to create a token without authentication."""
channel = self.make_request("POST", self.url + "/new", {})
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_create_requester_not_admin(self) -> None:
@@ -89,11 +84,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_create_using_defaults(self) -> None:
@@ -105,7 +96,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["token"]), 16)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
@@ -129,7 +120,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["token"], token)
self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
@@ -150,7 +141,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["token"]), 16)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
@@ -168,11 +159,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_token_invalid_chars(self) -> None:
@@ -188,11 +175,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_token_already_exists(self) -> None:
@@ -207,7 +190,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
data,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel1.code, msg=channel1.json_body)
+ self.assertEqual(200, channel1.code, msg=channel1.json_body)
channel2 = self.make_request(
"POST",
@@ -215,7 +198,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
data,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel2.code, msg=channel2.json_body)
+ self.assertEqual(400, channel2.code, msg=channel2.json_body)
self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_unable_to_generate_token(self) -> None:
@@ -251,7 +234,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 0},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["uses_allowed"], 0)
# Should fail with negative integer
@@ -262,7 +245,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
self.assertEqual(
- HTTPStatus.BAD_REQUEST,
+ 400,
channel.code,
msg=channel.json_body,
)
@@ -275,11 +258,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 1.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_expiry_time(self) -> None:
@@ -291,11 +270,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": self.clock.time_msec() - 10000},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with float
@@ -305,11 +280,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": self.clock.time_msec() + 1000000.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_length(self) -> None:
@@ -321,7 +292,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 64},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["token"]), 64)
# Should fail with 0
@@ -331,11 +302,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 0},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a negative integer
@@ -345,11 +312,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": -5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a float
@@ -359,11 +322,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 8.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with 65
@@ -373,11 +332,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 65},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# UPDATING
@@ -389,11 +344,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
)
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_update_requester_not_admin(self) -> None:
@@ -404,11 +355,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_update_non_existent(self) -> None:
@@ -420,11 +367,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.NOT_FOUND,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_update_uses_allowed(self) -> None:
@@ -439,7 +382,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 1},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertIsNone(channel.json_body["expiry_time"])
@@ -450,7 +393,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 0},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["uses_allowed"], 0)
self.assertIsNone(channel.json_body["expiry_time"])
@@ -461,7 +404,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": None},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
@@ -472,11 +415,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 1.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a negative integer
@@ -486,11 +425,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": -5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_update_expiry_time(self) -> None:
@@ -506,7 +441,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": new_expiry_time},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
self.assertIsNone(channel.json_body["uses_allowed"])
@@ -517,7 +452,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": None},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIsNone(channel.json_body["expiry_time"])
self.assertIsNone(channel.json_body["uses_allowed"])
@@ -529,11 +464,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": past_time},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail a float
@@ -543,11 +474,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": new_expiry_time + 0.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_update_both(self) -> None:
@@ -568,7 +495,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
@@ -589,11 +516,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# DELETING
@@ -605,11 +528,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
)
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_delete_requester_not_admin(self) -> None:
@@ -620,11 +539,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_delete_non_existent(self) -> None:
@@ -636,11 +551,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.NOT_FOUND,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_delete(self) -> None:
@@ -655,7 +566,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# GETTING ONE
@@ -666,11 +577,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
)
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_get_requester_not_admin(self) -> None:
@@ -682,7 +589,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg=channel.json_body,
)
@@ -697,11 +604,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.NOT_FOUND,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_get(self) -> None:
@@ -716,7 +619,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["token"], token)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
@@ -728,11 +631,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
def test_list_no_auth(self) -> None:
"""Try to list tokens without authentication."""
channel = self.make_request("GET", self.url, {})
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_list_requester_not_admin(self) -> None:
@@ -743,11 +642,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_list_all(self) -> None:
@@ -762,7 +657,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["registration_tokens"]), 1)
token_info = channel.json_body["registration_tokens"][0]
self.assertEqual(token_info["token"], token)
@@ -780,11 +675,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
def _test_list_query_parameter(self, valid: str) -> None:
"""Helper used to test both valid=true and valid=false."""
@@ -816,7 +707,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["registration_tokens"]), 2)
token_info_1 = channel.json_body["registration_tokens"][0]
token_info_2 = channel.json_body["registration_tokens"][1]
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index ca6af9417b..9d71a97524 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import urllib.parse
-from http import HTTPStatus
from typing import List, Optional
from unittest.mock import Mock
@@ -21,7 +20,7 @@ from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RoomTypes
from synapse.api.errors import Codes
from synapse.handlers.pagination import PaginationHandler
from synapse.rest.client import directory, events, login, room
@@ -68,7 +67,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
def test_requester_is_no_admin(self) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -78,7 +77,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_room_does_not_exist(self) -> None:
@@ -94,11 +93,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def test_room_is_not_valid(self) -> None:
"""
- Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
+ Check that invalid room names, return an error 400.
"""
url = "/_synapse/admin/v1/rooms/%s" % "invalidroom"
@@ -109,7 +108,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom is not a legal room ID",
channel.json_body["error"],
@@ -127,7 +126,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("new_room_id", channel.json_body)
self.assertIn("kicked_users", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -145,7 +144,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"User must be our own: @not:exist.bla",
channel.json_body["error"],
@@ -163,7 +162,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_is_not_bool(self) -> None:
@@ -178,7 +177,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_room_and_block(self) -> None:
@@ -202,7 +201,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -233,7 +232,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -265,7 +264,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -296,7 +295,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
)
# The room is now blocked.
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self._is_blocked(room_id)
def test_shutdown_room_consent(self) -> None:
@@ -319,7 +318,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self.room_id,
body="foo",
tok=self.other_user_tok,
- expect_code=HTTPStatus.FORBIDDEN,
+ expect_code=403,
)
# Test that room is not purged
@@ -337,7 +336,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("new_room_id", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -366,7 +365,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
{"history_visibility": "world_readable"},
access_token=self.other_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Test that room is not purged
with self.assertRaises(AssertionError):
@@ -383,7 +382,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("new_room_id", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -398,7 +397,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self._has_no_members(self.room_id)
# Assert we can no longer peek into the room
- self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN)
+ self._assert_peek(self.room_id, expect_code=403)
def _is_blocked(self, room_id: str, expect: bool = True) -> None:
"""Assert that the room is blocked or not"""
@@ -494,7 +493,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
)
def test_requester_is_no_admin(self, method: str, url: str) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -504,7 +503,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_room_does_not_exist(self) -> None:
@@ -522,7 +521,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"]
@@ -533,7 +532,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, len(channel.json_body["results"]))
self.assertEqual("complete", channel.json_body["results"][0]["status"])
self.assertEqual(delete_id, channel.json_body["results"][0]["delete_id"])
@@ -546,7 +545,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
)
def test_room_is_not_valid(self, method: str, url: str) -> None:
"""
- Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
+ Check that invalid room names, return an error 400.
"""
channel = self.make_request(
@@ -556,7 +555,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom is not a legal room ID",
channel.json_body["error"],
@@ -574,7 +573,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"]
@@ -592,7 +591,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"User must be our own: @not:exist.bla",
channel.json_body["error"],
@@ -610,7 +609,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_is_not_bool(self) -> None:
@@ -625,7 +624,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_delete_expired_status(self) -> None:
@@ -639,7 +638,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id1 = channel.json_body["delete_id"]
@@ -654,7 +653,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id2 = channel.json_body["delete_id"]
@@ -665,7 +664,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(2, len(channel.json_body["results"]))
self.assertEqual("complete", channel.json_body["results"][0]["status"])
self.assertEqual("complete", channel.json_body["results"][1]["status"])
@@ -682,7 +681,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, len(channel.json_body["results"]))
self.assertEqual("complete", channel.json_body["results"][0]["status"])
self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"])
@@ -696,7 +695,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_delete_same_room_twice(self) -> None:
@@ -722,9 +721,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST, second_channel.code, msg=second_channel.json_body
- )
+ self.assertEqual(400, second_channel.code, msg=second_channel.json_body)
self.assertEqual(Codes.UNKNOWN, second_channel.json_body["errcode"])
self.assertEqual(
f"History purge already in progress for {self.room_id}",
@@ -733,7 +730,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
# get result of first call
first_channel.await_result()
- self.assertEqual(HTTPStatus.OK, first_channel.code, msg=first_channel.json_body)
+ self.assertEqual(200, first_channel.code, msg=first_channel.json_body)
self.assertIn("delete_id", first_channel.json_body)
# check status after finish the task
@@ -764,7 +761,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"]
@@ -795,7 +792,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"]
@@ -827,7 +824,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"]
@@ -858,7 +855,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.room_id,
body="foo",
tok=self.other_user_tok,
- expect_code=HTTPStatus.FORBIDDEN,
+ expect_code=403,
)
# Test that room is not purged
@@ -876,7 +873,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"]
@@ -887,7 +884,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.url_status_by_room_id,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, len(channel.json_body["results"]))
# Test that member has moved to new room
@@ -914,7 +911,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
content={"history_visibility": "world_readable"},
access_token=self.other_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Test that room is not purged
with self.assertRaises(AssertionError):
@@ -931,7 +928,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"]
@@ -942,7 +939,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.url_status_by_room_id,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, len(channel.json_body["results"]))
# Test that member has moved to new room
@@ -955,7 +952,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self._has_no_members(self.room_id)
# Assert we can no longer peek into the room
- self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN)
+ self._assert_peek(self.room_id, expect_code=403)
def _is_blocked(self, room_id: str, expect: bool = True) -> None:
"""Assert that the room is blocked or not"""
@@ -1026,9 +1023,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.url_status_by_room_id,
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.OK, channel_room_id.code, msg=channel_room_id.json_body
- )
+ self.assertEqual(200, channel_room_id.code, msg=channel_room_id.json_body)
self.assertEqual(1, len(channel_room_id.json_body["results"]))
self.assertEqual(
delete_id, channel_room_id.json_body["results"][0]["delete_id"]
@@ -1041,7 +1036,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
self.assertEqual(
- HTTPStatus.OK,
+ 200,
channel_delete_id.code,
msg=channel_delete_id.json_body,
)
@@ -1085,7 +1080,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
room_ids = []
for _ in range(total_rooms):
room_id = self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok
+ self.admin_user,
+ tok=self.admin_user_tok,
+ is_public=True,
)
room_ids.append(room_id)
@@ -1100,7 +1097,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
# Check request completed successfully
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that response json body contains a "rooms" key
self.assertTrue(
@@ -1124,12 +1121,14 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertIn("version", r)
self.assertIn("creator", r)
self.assertIn("encryption", r)
- self.assertIn("federatable", r)
- self.assertIn("public", r)
+ self.assertIs(r["federatable"], True)
+ self.assertIs(r["public"], True)
self.assertIn("join_rules", r)
self.assertIn("guest_access", r)
self.assertIn("history_visibility", r)
self.assertIn("state_events", r)
+ self.assertIn("room_type", r)
+ self.assertIsNone(r["room_type"])
# Check that the correct number of total rooms was returned
self.assertEqual(channel.json_body["total_rooms"], total_rooms)
@@ -1184,7 +1183,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue("rooms" in channel.json_body)
for r in channel.json_body["rooms"]:
@@ -1224,12 +1223,16 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def test_correct_room_attributes(self) -> None:
"""Test the correct attributes for a room are returned"""
# Create a test room
- room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id = self.helper.create_room_as(
+ self.admin_user,
+ tok=self.admin_user_tok,
+ extra_content={"creation_content": {"type": RoomTypes.SPACE}},
+ )
test_alias = "#test:test"
test_room_name = "something"
@@ -1247,7 +1250,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
{"room_id": room_id},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Set this new alias as the canonical alias for this room
self.helper.send_state(
@@ -1279,7 +1282,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that rooms were returned
self.assertTrue("rooms" in channel.json_body)
@@ -1306,6 +1309,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(room_id, r["room_id"])
self.assertEqual(test_room_name, r["name"])
self.assertEqual(test_alias, r["canonical_alias"])
+ self.assertEqual(RoomTypes.SPACE, r["room_type"])
def test_room_list_sort_order(self) -> None:
"""Test room list sort ordering. alphabetical name versus number of members,
@@ -1334,7 +1338,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that rooms were returned
self.assertTrue("rooms" in channel.json_body)
@@ -1480,7 +1484,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
def _search_test(
expected_room_id: Optional[str],
search_term: str,
- expected_http_code: int = HTTPStatus.OK,
+ expected_http_code: int = 200,
) -> None:
"""Search for a room and check that the returned room's id is a match
@@ -1498,7 +1502,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
- if expected_http_code != HTTPStatus.OK:
+ if expected_http_code != 200:
return
# Check that rooms were returned
@@ -1541,7 +1545,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
_search_test(None, "foo")
_search_test(None, "bar")
- _search_test(None, "", expected_http_code=HTTPStatus.BAD_REQUEST)
+ _search_test(None, "", expected_http_code=400)
# Test that the whole room id returns the room
_search_test(room_id_1, room_id_1)
@@ -1578,15 +1582,19 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
- self.assertEqual(room_id, channel.json_body.get("rooms")[0].get("room_id"))
- self.assertEqual("ж", channel.json_body.get("rooms")[0].get("name"))
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(room_id, channel.json_body["rooms"][0].get("room_id"))
+ self.assertEqual("ж", channel.json_body["rooms"][0].get("name"))
def test_single_room(self) -> None:
"""Test that a single room can be requested correctly"""
# Create two test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_1 = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok, is_public=True
+ )
+ room_id_2 = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok, is_public=False
+ )
room_name_1 = "something"
room_name_2 = "else"
@@ -1611,7 +1619,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("room_id", channel.json_body)
self.assertIn("name", channel.json_body)
@@ -1630,8 +1638,12 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertIn("guest_access", channel.json_body)
self.assertIn("history_visibility", channel.json_body)
self.assertIn("state_events", channel.json_body)
+ self.assertIn("room_type", channel.json_body)
+ self.assertIn("forgotten", channel.json_body)
self.assertEqual(room_id_1, channel.json_body["room_id"])
+ self.assertIs(True, channel.json_body["federatable"])
+ self.assertIs(True, channel.json_body["public"])
def test_single_room_devices(self) -> None:
"""Test that `joined_local_devices` can be requested correctly"""
@@ -1643,7 +1655,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["joined_local_devices"])
# Have another user join the room
@@ -1657,7 +1669,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(2, channel.json_body["joined_local_devices"])
# leave room
@@ -1669,7 +1681,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["joined_local_devices"])
def test_room_members(self) -> None:
@@ -1700,7 +1712,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertCountEqual(
["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"]
@@ -1713,7 +1725,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertCountEqual(
["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"]
@@ -1731,7 +1743,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("state", channel.json_body)
# testing that the state events match is painful and not done here. We assume that
# the create_room already does the right thing, so no need to verify that we got
@@ -1748,7 +1760,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
{"room_id": room_id},
access_token=admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Set this new alias as the canonical alias for this room
self.helper.send_state(
@@ -1765,6 +1777,21 @@ class RoomTestCase(unittest.HomeserverTestCase):
tok=admin_user_tok,
)
+ def test_get_joined_members_after_leave_room(self) -> None:
+ """Test that requesting room members after leaving the room raises a 403 error."""
+
+ # create the room
+ user = self.register_user("foo", "pass")
+ user_tok = self.login("foo", "pass")
+ room_id = self.helper.create_room_as(user, tok=user_tok)
+ self.helper.leave(room_id, user, tok=user_tok)
+
+ # delete the rooms and get joined roomed membership
+ url = f"/_matrix/client/r0/rooms/{room_id}/joined_members"
+ channel = self.make_request("GET", url.encode("ascii"), access_token=user_tok)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
@@ -1791,7 +1818,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
def test_requester_is_no_admin(self) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -1801,7 +1828,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.second_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -1816,12 +1843,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
def test_local_user_does_not_exist(self) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
channel = self.make_request(
@@ -1831,7 +1858,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_remote_user(self) -> None:
@@ -1846,7 +1873,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"This endpoint can only be used with local users",
channel.json_body["error"],
@@ -1854,7 +1881,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
def test_room_does_not_exist(self) -> None:
"""
- Check that unknown rooms/server return error HTTPStatus.NOT_FOUND.
+ Check that unknown rooms/server return error 404.
"""
url = "/_synapse/admin/v1/join/!unknown:test"
@@ -1865,12 +1892,15 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
- self.assertEqual("No known servers", channel.json_body["error"])
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ "Can't join remote room because no servers that are in the room have been provided.",
+ channel.json_body["error"],
+ )
def test_room_is_not_valid(self) -> None:
"""
- Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
+ Check that invalid room names, return an error 400.
"""
url = "/_synapse/admin/v1/join/invalidroom"
@@ -1881,7 +1911,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom was not legal room ID or room alias",
channel.json_body["error"],
@@ -1899,7 +1929,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.public_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -1909,7 +1939,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
def test_join_private_room_if_not_member(self) -> None:
@@ -1929,7 +1959,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_join_private_room_if_member(self) -> None:
@@ -1957,7 +1987,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
# Join user to room.
@@ -1970,7 +2000,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
content={"user_id": self.second_user_id},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -1980,7 +2010,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
def test_join_private_room_if_owner(self) -> None:
@@ -2000,7 +2030,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -2010,7 +2040,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
def test_context_as_non_admin(self) -> None:
@@ -2044,7 +2074,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
% (room_id, events[midway]["event_id"]),
access_token=tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_context_as_admin(self) -> None:
@@ -2074,7 +2104,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
% (room_id, events[midway]["event_id"]),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body["event"]["event_id"], events[midway]["event_id"]
)
@@ -2133,7 +2163,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room and ban a user.
self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
@@ -2160,7 +2190,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room (we should have received an
# invite) and can ban a user.
@@ -2186,7 +2216,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room and ban a user.
self.helper.join(room_id, self.second_user_id, tok=self.second_tok)
@@ -2220,11 +2250,11 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- # We expect this to fail with a HTTPStatus.BAD_REQUEST as there are no room admins.
+ # We expect this to fail with a 400 as there are no room admins.
#
# (Note we assert the error message to ensure that it's not denied for
# some other reason)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body["error"],
"No local admin user in room with power to update power levels.",
@@ -2254,7 +2284,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
@parameterized.expand([("PUT",), ("GET",)])
def test_requester_is_no_admin(self, method: str) -> None:
- """If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned."""
+ """If the user is not a server admin, an error 403 is returned."""
channel = self.make_request(
method,
@@ -2263,12 +2293,12 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand([("PUT",), ("GET",)])
def test_room_is_not_valid(self, method: str) -> None:
- """Check that invalid room names, return an error HTTPStatus.BAD_REQUEST."""
+ """Check that invalid room names, return an error 400."""
channel = self.make_request(
method,
@@ -2277,7 +2307,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom is not a legal room ID",
channel.json_body["error"],
@@ -2294,7 +2324,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
# `block` is not set
@@ -2305,7 +2335,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
# no content is send
@@ -2315,7 +2345,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
def test_block_room(self) -> None:
@@ -2329,7 +2359,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
content={"block": True},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["block"])
self._is_blocked(room_id, expect=True)
@@ -2353,7 +2383,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
content={"block": True},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["block"])
self._is_blocked(self.room_id, expect=True)
@@ -2369,7 +2399,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
content={"block": False},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body["block"])
self._is_blocked(room_id, expect=False)
@@ -2393,7 +2423,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
content={"block": False},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body["block"])
self._is_blocked(self.room_id, expect=False)
@@ -2408,7 +2438,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
self.url % room_id,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["block"])
self.assertEqual(self.other_user, channel.json_body["user_id"])
@@ -2432,7 +2462,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
self.url % room_id,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body["block"])
self.assertNotIn("user_id", channel.json_body)
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index dbcba2663c..a2f347f666 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from typing import List
from twisted.test.proto_helpers import MemoryReactor
@@ -57,7 +56,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url)
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -72,7 +71,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg=channel.json_body,
)
@@ -80,7 +79,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
@override_config({"server_notices": {"system_mxid_localpart": "notices"}})
def test_user_does_not_exist(self) -> None:
- """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND"""
+ """Tests that a lookup for a user that does not exist returns a 404"""
channel = self.make_request(
"POST",
self.url,
@@ -88,13 +87,13 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": "@unknown_person:test", "content": ""},
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@override_config({"server_notices": {"system_mxid_localpart": "notices"}})
def test_user_is_not_local(self) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
channel = self.make_request(
"POST",
@@ -106,7 +105,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"Server notices can only be sent to local users", channel.json_body["error"]
)
@@ -122,7 +121,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
# no content
@@ -133,7 +132,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": self.other_user},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
# no body
@@ -144,7 +143,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": self.other_user, "content": ""},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual("'body' not in content", channel.json_body["error"])
@@ -156,10 +155,66 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": self.other_user, "content": {"body": ""}},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual("'msgtype' not in content", channel.json_body["error"])
+ @override_config(
+ {
+ "server_notices": {
+ "system_mxid_localpart": "notices",
+ "system_mxid_avatar_url": "somthingwrong",
+ },
+ "max_avatar_size": "10M",
+ }
+ )
+ def test_invalid_avatar_url(self) -> None:
+ """If avatar url in homeserver.yaml is invalid and
+ "check avatar size and mime type" is set, an error is returned.
+ TODO: Should be checked when reading the configuration."""
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg"},
+ },
+ )
+
+ self.assertEqual(500, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ @override_config(
+ {
+ "server_notices": {
+ "system_mxid_localpart": "notices",
+ "system_mxid_display_name": "test display name",
+ "system_mxid_avatar_url": None,
+ },
+ "max_avatar_size": "10M",
+ }
+ )
+ def test_displayname_is_set_avatar_is_none(self) -> None:
+ """
+ Tests that sending a server notices is successfully,
+ if a display_name is set, avatar_url is `None` and
+ "check avatar size and mime type" is set.
+ """
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ self._check_invite_and_join_status(self.other_user, 1, 0)
+
def test_server_notice_disabled(self) -> None:
"""Tests that server returns error if server notice is disabled"""
channel = self.make_request(
@@ -172,7 +227,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual(
"Server notices are not enabled on this server", channel.json_body["error"]
@@ -197,7 +252,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg one"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -226,7 +281,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg two"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has no new invites or memberships
self._check_invite_and_join_status(self.other_user, 0, 1)
@@ -260,7 +315,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg one"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -301,7 +356,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg two"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -341,7 +396,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg one"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -388,7 +443,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg two"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -538,7 +593,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", "/_matrix/client/r0/sync", access_token=token
)
- self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(channel.code, 200)
# Get the messages
room = channel.json_body["rooms"]["join"][room_id]
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 7cb8ec57ba..b60f16b914 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from typing import List, Optional
from twisted.test.proto_helpers import MemoryReactor
@@ -51,16 +50,12 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
"GET",
@@ -69,11 +64,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -87,11 +78,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -101,11 +88,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative limit
@@ -115,11 +98,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from_ts
@@ -129,11 +108,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative until_ts
@@ -143,11 +118,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# until_ts smaller from_ts
@@ -157,11 +128,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# empty search term
@@ -171,11 +138,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -185,11 +148,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_limit(self) -> None:
@@ -204,7 +163,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["users"]), 5)
self.assertEqual(channel.json_body["next_token"], 5)
@@ -222,7 +181,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["users"]), 15)
self.assertNotIn("next_token", channel.json_body)
@@ -240,7 +199,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(channel.json_body["next_token"], 15)
self.assertEqual(len(channel.json_body["users"]), 10)
@@ -262,7 +221,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
@@ -275,7 +234,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
@@ -288,7 +247,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 19)
self.assertEqual(channel.json_body["next_token"], 19)
@@ -301,7 +260,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -318,7 +277,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["users"]))
@@ -415,7 +374,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
# filter media starting at `ts1` after creating first media
@@ -425,7 +384,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?from_ts=%s" % (ts1,),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 0)
self._create_media(self.other_user_tok, 3)
@@ -440,7 +399,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
# filter media until `ts2` and earlier
@@ -449,7 +408,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?until_ts=%s" % (ts2,),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["users"][0]["media_count"], 6)
def test_search_term(self) -> None:
@@ -461,7 +420,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
# filter user 1 and 10-19 by `user_id`
@@ -470,7 +429,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?search_term=foo_user_1",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 11)
# filter on this user in `displayname`
@@ -479,7 +438,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?search_term=bar_user_10",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["users"][0]["displayname"], "bar_user_10")
self.assertEqual(channel.json_body["total"], 1)
@@ -489,7 +448,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?search_term=foobar",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 0)
def _create_users_with_media(self, number_users: int, media_per_user: int) -> None:
@@ -515,7 +474,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
for _ in range(number_media):
# Upload some media into the room
self.helper.upload_media(
- upload_resource, SMALL_PNG, tok=user_token, expect_code=HTTPStatus.OK
+ upload_resource, SMALL_PNG, tok=user_token, expect_code=200
)
def _check_fields(self, content: List[JsonDict]) -> None:
@@ -549,7 +508,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_user_list))
returned_order = [row["user_id"] for row in channel.json_body["users"]]
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 0d44102237..1afd082707 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1,4 +1,4 @@
-# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2018-2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@ import hmac
import os
import urllib.parse
from binascii import unhexlify
-from http import HTTPStatus
from typing import List, Optional
from unittest.mock import Mock, patch
@@ -79,7 +78,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"Shared secret registration is not enabled", channel.json_body["error"]
)
@@ -111,7 +110,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = {"nonce": nonce}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("username must be specified", channel.json_body["error"])
# 61 seconds
@@ -119,7 +118,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"])
def test_register_incorrect_nonce(self) -> None:
@@ -142,7 +141,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("HMAC incorrect", channel.json_body["error"])
def test_register_correct_nonce(self) -> None:
@@ -169,7 +168,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
def test_nonce_reuse(self) -> None:
@@ -192,13 +191,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
# Now, try and reuse it
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"])
def test_missing_parts(self) -> None:
@@ -219,7 +218,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be an empty body present
channel = self.make_request("POST", self.url, {})
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("nonce must be specified", channel.json_body["error"])
#
@@ -229,28 +228,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be present
channel = self.make_request("POST", self.url, {"nonce": nonce()})
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("username must be specified", channel.json_body["error"])
# Must be a string
body = {"nonce": nonce(), "username": 1234}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
body = {"nonce": nonce(), "username": "abcd\u0000"}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
body = {"nonce": nonce(), "username": "a" * 1000}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
#
@@ -261,28 +260,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = {"nonce": nonce(), "username": "a"}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("password must be specified", channel.json_body["error"])
# Must be a string
body = {"nonce": nonce(), "username": "a", "password": 1234}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
# Must not have null bytes
body = {"nonce": nonce(), "username": "a", "password": "abcd\u0000"}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
# Super long
body = {"nonce": nonce(), "username": "a", "password": "A" * 1000}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
#
@@ -298,7 +297,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid user type", channel.json_body["error"])
def test_displayname(self) -> None:
@@ -323,11 +322,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob1:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob1:test/displayname")
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("bob1", channel.json_body["displayname"])
# displayname is None
@@ -347,11 +346,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob2:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob2:test/displayname")
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("bob2", channel.json_body["displayname"])
# displayname is empty
@@ -371,11 +370,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob3:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob3:test/displayname")
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
# set displayname
channel = self.make_request("GET", self.url)
@@ -394,11 +393,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob4:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob4:test/displayname")
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("Bob's Name", channel.json_body["displayname"])
@override_config(
@@ -442,7 +441,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -466,7 +465,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
@@ -478,7 +477,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", self.url, access_token=other_user_token)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_all_users(self) -> None:
@@ -494,7 +493,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(3, len(channel.json_body["users"]))
self.assertEqual(3, channel.json_body["total"])
@@ -508,7 +507,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
expected_user_id: Optional[str],
search_term: str,
search_field: Optional[str] = "name",
- expected_http_code: Optional[int] = HTTPStatus.OK,
+ expected_http_code: Optional[int] = 200,
) -> None:
"""Search for a user and check that the returned user's id is a match
@@ -530,7 +529,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
- if expected_http_code != HTTPStatus.OK:
+ if expected_http_code != 200:
return
# Check that users were returned
@@ -591,7 +590,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -601,7 +600,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid guests
@@ -611,7 +610,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid deactivated
@@ -621,7 +620,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# unkown order_by
@@ -631,7 +630,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -641,7 +640,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_limit(self) -> None:
@@ -659,7 +658,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 5)
self.assertEqual(channel.json_body["next_token"], "5")
@@ -680,7 +679,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 15)
self.assertNotIn("next_token", channel.json_body)
@@ -701,7 +700,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(channel.json_body["next_token"], "15")
self.assertEqual(len(channel.json_body["users"]), 10)
@@ -724,7 +723,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
@@ -737,7 +736,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
@@ -750,7 +749,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 19)
self.assertEqual(channel.json_body["next_token"], "19")
@@ -764,7 +763,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -867,7 +866,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_user_list))
returned_order = [row["name"] for row in channel.json_body["users"]]
@@ -905,6 +904,96 @@ class UsersListTestCase(unittest.HomeserverTestCase):
)
+class UserDevicesTestCase(unittest.HomeserverTestCase):
+ """
+ Tests user device management-related Admin APIs.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ # Set up an Admin user to query the Admin API with.
+ self.admin_user_id = self.register_user("admin", "pass", admin=True)
+ self.admin_user_token = self.login("admin", "pass")
+
+ # Set up a test user to query the devices of.
+ self.other_user_device_id = "TESTDEVICEID"
+ self.other_user_device_display_name = "My Test Device"
+ self.other_user_client_ip = "1.2.3.4"
+ self.other_user_user_agent = "EquestriaTechnology/123.0"
+
+ self.other_user_id = self.register_user("user", "pass", displayname="User1")
+ self.other_user_token = self.login(
+ "user",
+ "pass",
+ device_id=self.other_user_device_id,
+ additional_request_fields={
+ "initial_device_display_name": self.other_user_device_display_name,
+ },
+ )
+
+ # Have the "other user" make a request so that the "last_seen_*" fields are
+ # populated in the tests below.
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/sync",
+ access_token=self.other_user_token,
+ client_ip=self.other_user_client_ip,
+ custom_headers=[
+ ("User-Agent", self.other_user_user_agent),
+ ],
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ def test_list_user_devices(self) -> None:
+ """
+ Tests that a user's devices and attributes are listed correctly via the Admin API.
+ """
+ # Request all devices of "other user"
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v2/users/{self.other_user_id}/devices",
+ access_token=self.admin_user_token,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Double-check we got the single device expected
+ user_devices = channel.json_body["devices"]
+ self.assertEqual(len(user_devices), 1)
+ self.assertEqual(channel.json_body["total"], 1)
+
+ # Check that all the attributes of the device reported are as expected.
+ self._validate_attributes_of_device_response(user_devices[0])
+
+ # Request just a single device for "other user" by its ID
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v2/users/{self.other_user_id}/devices/"
+ f"{self.other_user_device_id}",
+ access_token=self.admin_user_token,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Check that all the attributes of the device reported are as expected.
+ self._validate_attributes_of_device_response(channel.json_body)
+
+ def _validate_attributes_of_device_response(self, response: JsonDict) -> None:
+ # Check that all device expected attributes are present
+ self.assertEqual(response["user_id"], self.other_user_id)
+ self.assertEqual(response["device_id"], self.other_user_device_id)
+ self.assertEqual(response["display_name"], self.other_user_device_display_name)
+ self.assertEqual(response["last_seen_ip"], self.other_user_client_ip)
+ self.assertEqual(response["last_seen_user_agent"], self.other_user_user_agent)
+ self.assertIsInstance(response["last_seen_ts"], int)
+ self.assertGreater(response["last_seen_ts"], 0)
+
+
class DeactivateAccountTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -941,7 +1030,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_not_admin(self) -> None:
@@ -952,7 +1041,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", url, access_token=self.other_user_token)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
channel = self.make_request(
@@ -962,12 +1051,12 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content=b"{}",
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
def test_user_does_not_exist(self) -> None:
"""
- Tests that deactivation for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that deactivation for a user that does not exist returns a 404
"""
channel = self.make_request(
@@ -976,7 +1065,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_erase_is_not_bool(self) -> None:
@@ -991,18 +1080,18 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
"""
- Tests that deactivation for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that deactivation for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain"
channel = self.make_request("POST", url, access_token=self.admin_user_tok)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only deactivate local users", channel.json_body["error"])
def test_deactivate_user_erase_true(self) -> None:
@@ -1017,7 +1106,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@@ -1032,7 +1121,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content={"erase": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Get user
channel = self.make_request(
@@ -1041,7 +1130,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
@@ -1066,7 +1155,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"erase": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self._is_erased("@user:test", True)
def test_deactivate_user_erase_false(self) -> None:
@@ -1081,7 +1170,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@@ -1096,7 +1185,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content={"erase": False},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Get user
channel = self.make_request(
@@ -1105,7 +1194,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
@@ -1135,7 +1224,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@@ -1150,7 +1239,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content={"erase": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Get user
channel = self.make_request(
@@ -1159,7 +1248,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
@@ -1220,7 +1309,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
channel = self.make_request(
@@ -1230,12 +1319,12 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=b"{}",
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
def test_user_does_not_exist(self) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
channel = self.make_request(
@@ -1244,7 +1333,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -1259,7 +1348,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"admin": "not_bool"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
# deactivated not bool
@@ -1269,7 +1358,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": "not_bool"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# password not str
@@ -1279,7 +1368,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"password": True},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# password not length
@@ -1289,7 +1378,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"password": "x" * 513},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# user_type not valid
@@ -1299,7 +1388,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"user_type": "new type"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# external_ids not valid
@@ -1311,7 +1400,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"external_ids": {"auth_provider": "prov", "wrong_external_id": "id"}
},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
channel = self.make_request(
@@ -1320,7 +1409,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"external_ids": {"external_id": "id"}},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
# threepids not valid
@@ -1330,7 +1419,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"threepids": {"medium": "email", "wrong_address": "id"}},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
channel = self.make_request(
@@ -1339,7 +1428,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"threepids": {"address": "value"}},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
def test_get_user(self) -> None:
@@ -1352,7 +1441,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("User", channel.json_body["displayname"])
self._check_fields(channel.json_body)
@@ -1395,7 +1484,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1458,7 +1547,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1486,9 +1575,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# before limit of monthly active users is reached
channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok)
- if channel.code != HTTPStatus.OK:
+ if channel.code != 200:
raise HttpResponseException(
- channel.code, channel.result["reason"], channel.json_body
+ channel.code, channel.result["reason"], channel.result["body"]
)
# Set monthly active users to the limit
@@ -1636,6 +1725,41 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(len(pushers), 0)
+ @override_config(
+ {
+ "email": {
+ "enable_notifs": True,
+ "notif_for_new_users": True,
+ "notif_from": "test@example.com",
+ },
+ "public_baseurl": "https://example.com",
+ }
+ )
+ def test_create_user_email_notif_for_new_users_with_msisdn_threepid(self) -> None:
+ """
+ Check that a new regular user is created successfully when they have a msisdn
+ threepid and email notif_for_new_users is set to True.
+ """
+ url = self.url_prefix % "@bob:test"
+
+ # Create user
+ body = {
+ "password": "abc123",
+ "threepids": [{"medium": "msisdn", "address": "1234567890"}],
+ }
+
+ channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body,
+ )
+
+ self.assertEqual(201, channel.code, msg=channel.json_body)
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"])
+
def test_set_password(self) -> None:
"""
Test setting a new password for another user.
@@ -1649,7 +1773,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "hahaha"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self._check_fields(channel.json_body)
def test_set_displayname(self) -> None:
@@ -1665,7 +1789,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"displayname": "foobar"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
@@ -1676,7 +1800,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
@@ -1698,7 +1822,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
# result does not always have the same sort order, therefore it becomes sorted
@@ -1724,7 +1848,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1740,7 +1864,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1756,7 +1880,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"threepids": []},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self._check_fields(channel.json_body)
@@ -1783,7 +1907,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1802,7 +1926,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1824,7 +1948,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# other user has this two threepids
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
# result does not always have the same sort order, therefore it becomes sorted
@@ -1843,7 +1967,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
url_first_user,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self._check_fields(channel.json_body)
@@ -1872,7 +1996,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["external_ids"]))
# result does not always have the same sort order, therefore it becomes sorted
@@ -1904,7 +2028,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -1923,7 +2047,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -1942,7 +2066,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"external_ids": []},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(0, len(channel.json_body["external_ids"]))
@@ -1971,7 +2095,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -1997,7 +2121,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -2029,7 +2153,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# must fail
- self.assertEqual(HTTPStatus.CONFLICT, channel.code, msg=channel.json_body)
+ self.assertEqual(409, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual("External id is already in use.", channel.json_body["error"])
@@ -2040,7 +2164,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -2058,7 +2182,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -2089,7 +2213,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@@ -2104,7 +2228,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"deactivated": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
@@ -2123,7 +2247,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
@@ -2153,7 +2277,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"deactivated": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
@@ -2169,7 +2293,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"displayname": "Foobar"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
self.assertEqual("Foobar", channel.json_body["displayname"])
@@ -2193,7 +2317,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
# Reactivate the user.
channel = self.make_request(
@@ -2202,7 +2326,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self._is_erased("@user:test", False)
@@ -2226,7 +2350,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Reactivate the user without a password.
@@ -2236,7 +2360,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self._is_erased("@user:test", False)
@@ -2260,7 +2384,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Reactivate the user without a password.
@@ -2270,7 +2394,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self._is_erased("@user:test", False)
@@ -2291,7 +2415,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"admin": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["admin"])
@@ -2302,7 +2426,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["admin"])
@@ -2319,7 +2443,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"user_type": UserTypes.SUPPORT},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
@@ -2330,7 +2454,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
@@ -2342,7 +2466,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"user_type": None},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertIsNone(channel.json_body["user_type"])
@@ -2353,7 +2477,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertIsNone(channel.json_body["user_type"])
@@ -2383,7 +2507,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])
self.assertEqual(0, channel.json_body["deactivated"])
@@ -2396,7 +2520,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123", "deactivated": "false"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
# Check user is not deactivated
channel = self.make_request(
@@ -2405,7 +2529,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])
@@ -2430,7 +2554,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["deactivated"])
self._is_erased(user_id, False)
d = self.store.mark_user_erased(user_id)
@@ -2485,7 +2609,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
@@ -2500,7 +2624,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self) -> None:
@@ -2514,7 +2638,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["joined_rooms"]))
@@ -2530,7 +2654,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["joined_rooms"]))
@@ -2546,7 +2670,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["joined_rooms"]))
@@ -2567,7 +2691,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_rooms, channel.json_body["total"])
self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
@@ -2614,7 +2738,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"])
@@ -2643,7 +2767,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
@@ -2658,12 +2782,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/pushers"
channel = self.make_request(
@@ -2672,12 +2796,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers"
@@ -2687,7 +2811,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only look up local users", channel.json_body["error"])
def test_get_pushers(self) -> None:
@@ -2702,7 +2826,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
# Register the pusher
@@ -2734,7 +2858,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
for p in channel.json_body["pushers"]:
@@ -2773,7 +2897,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
"""Try to list media of an user without authentication."""
channel = self.make_request(method, self.url, {})
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "DELETE"])
@@ -2787,12 +2911,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "DELETE"])
def test_user_does_not_exist(self, method: str) -> None:
- """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND"""
+ """Tests that a lookup for a user that does not exist returns a 404"""
url = "/_synapse/admin/v1/users/@unknown_person:test/media"
channel = self.make_request(
method,
@@ -2800,12 +2924,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@parameterized.expand(["GET", "DELETE"])
def test_user_is_not_local(self, method: str) -> None:
- """Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST"""
+ """Tests that a lookup for a user that is not a local returns a 400"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media"
channel = self.make_request(
@@ -2814,7 +2938,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only look up local users", channel.json_body["error"])
def test_limit_GET(self) -> None:
@@ -2830,7 +2954,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 5)
self.assertEqual(channel.json_body["next_token"], 5)
@@ -2849,7 +2973,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 5)
self.assertEqual(len(channel.json_body["deleted_media"]), 5)
@@ -2866,7 +2990,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 15)
self.assertNotIn("next_token", channel.json_body)
@@ -2885,7 +3009,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 15)
self.assertEqual(len(channel.json_body["deleted_media"]), 15)
@@ -2902,7 +3026,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(channel.json_body["next_token"], 15)
self.assertEqual(len(channel.json_body["media"]), 10)
@@ -2921,7 +3045,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["deleted_media"]), 10)
@@ -2935,7 +3059,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -2945,7 +3069,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative limit
@@ -2955,7 +3079,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -2965,7 +3089,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_next_token(self) -> None:
@@ -2988,7 +3112,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), number_media)
self.assertNotIn("next_token", channel.json_body)
@@ -3001,7 +3125,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), number_media)
self.assertNotIn("next_token", channel.json_body)
@@ -3014,7 +3138,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 19)
self.assertEqual(channel.json_body["next_token"], 19)
@@ -3028,7 +3152,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -3045,7 +3169,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["media"]))
@@ -3060,7 +3184,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["deleted_media"]))
@@ -3077,7 +3201,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_media, channel.json_body["total"])
self.assertEqual(number_media, len(channel.json_body["media"]))
self.assertNotIn("next_token", channel.json_body)
@@ -3103,7 +3227,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_media, channel.json_body["total"])
self.assertEqual(number_media, len(channel.json_body["deleted_media"]))
self.assertCountEqual(channel.json_body["deleted_media"], media_ids)
@@ -3248,7 +3372,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource, image_data, user_token, filename, expect_code=HTTPStatus.OK
+ upload_resource, image_data, user_token, filename, expect_code=200
)
# Extract media ID from the response
@@ -3266,10 +3390,10 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- HTTPStatus.OK,
+ 200,
channel.code,
msg=(
- f"Expected to receive a HTTPStatus.OK on accessing media: {server_and_media_id}"
+ f"Expected to receive a 200 on accessing media: {server_and_media_id}"
),
)
@@ -3315,7 +3439,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_media_list))
returned_order = [row["media_id"] for row in channel.json_body["media"]]
@@ -3351,14 +3475,14 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST", self.url, b"{}", access_token=self.admin_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
return channel.json_body["access_token"]
def test_no_auth(self) -> None:
"""Try to login as a user without authentication."""
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_not_admin(self) -> None:
@@ -3367,7 +3491,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
"POST", self.url, b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
def test_send_event(self) -> None:
"""Test that sending event as a user works."""
@@ -3392,7 +3516,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# We should only see the one device (from the login in `prepare`)
self.assertEqual(len(channel.json_body["devices"]), 1)
@@ -3404,21 +3528,21 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Logout with the puppet token
channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
# .. but the real user's tokens should still work
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def test_user_logout_all(self) -> None:
"""Tests that the target user calling `/logout/all` does *not* expire
@@ -3429,23 +3553,23 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Logout all with the real user token
channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# The puppet token should still work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# .. but the real user's tokens shouldn't
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
def test_admin_logout_all(self) -> None:
"""Tests that the admin user calling `/logout/all` does expire the
@@ -3456,23 +3580,23 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Logout all with the admin user token
channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.admin_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
# .. but the real user's tokens should still work
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
@unittest.override_config(
{
@@ -3503,7 +3627,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
room_id,
"com.example.test",
tok=self.other_user_tok,
- expect_code=HTTPStatus.FORBIDDEN,
+ expect_code=403,
)
# Login in as the user
@@ -3524,7 +3648,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
room_id,
user=self.other_user,
tok=self.other_user_tok,
- expect_code=HTTPStatus.FORBIDDEN,
+ expect_code=403,
)
# Logging in as the other user and joining a room should work, even
@@ -3559,7 +3683,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
Try to get information of an user without authentication.
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_not_admin(self) -> None:
@@ -3574,12 +3698,12 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.url,
access_token=other_user2_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = self.url_prefix % "@unknown_person:unknown_domain" # type: ignore[attr-defined]
@@ -3588,7 +3712,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only whois a local user", channel.json_body["error"])
def test_get_whois_admin(self) -> None:
@@ -3600,7 +3724,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
@@ -3615,7 +3739,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.url,
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
@@ -3645,7 +3769,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
Try to get information of an user without authentication.
"""
channel = self.make_request(method, self.url)
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["POST", "DELETE"])
@@ -3656,18 +3780,18 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
other_user_token = self.login("user", "pass")
channel = self.make_request(method, self.url, access_token=other_user_token)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["POST", "DELETE"])
def test_user_is_not_local(self, method: str) -> None:
"""
- Tests that shadow-banning for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that shadow-banning for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
channel = self.make_request(method, url, access_token=self.admin_user_tok)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
def test_success(self) -> None:
"""
@@ -3680,7 +3804,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
self.assertFalse(result.shadow_banned)
channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual({}, channel.json_body)
# Ensure the user is shadow-banned (and the cache was cleared).
@@ -3692,7 +3816,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE", self.url, access_token=self.admin_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual({}, channel.json_body)
# Ensure the user is no longer shadow-banned (and the cache was cleared).
@@ -3727,7 +3851,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request(method, self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "POST", "DELETE"])
@@ -3743,13 +3867,13 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "POST", "DELETE"])
def test_user_does_not_exist(self, method: str) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit"
@@ -3759,7 +3883,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@parameterized.expand(
@@ -3771,7 +3895,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
)
def test_user_is_not_local(self, method: str, error_msg: str) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = (
"/_synapse/admin/v1/users/@unknown_person:unknown_domain/override_ratelimit"
@@ -3783,7 +3907,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(error_msg, channel.json_body["error"])
def test_invalid_parameter(self) -> None:
@@ -3798,7 +3922,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"messages_per_second": "string"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# messages_per_second is negative
@@ -3809,7 +3933,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"messages_per_second": -1},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# burst_count is a string
@@ -3820,7 +3944,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"burst_count": "string"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# burst_count is negative
@@ -3831,7 +3955,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"burst_count": -1},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_return_zero_when_null(self) -> None:
@@ -3856,7 +3980,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["messages_per_second"])
self.assertEqual(0, channel.json_body["burst_count"])
@@ -3870,7 +3994,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body)
@@ -3881,7 +4005,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"messages_per_second": 10, "burst_count": 11},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(10, channel.json_body["messages_per_second"])
self.assertEqual(11, channel.json_body["burst_count"])
@@ -3892,7 +4016,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"messages_per_second": 20, "burst_count": 21},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(20, channel.json_body["messages_per_second"])
self.assertEqual(21, channel.json_body["burst_count"])
@@ -3902,7 +4026,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(20, channel.json_body["messages_per_second"])
self.assertEqual(21, channel.json_body["burst_count"])
@@ -3912,7 +4036,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body)
@@ -3922,7 +4046,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body)
@@ -3947,7 +4071,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
"""Try to get information of a user without authentication."""
channel = self.make_request("GET", self.url, {})
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
@@ -3960,7 +4084,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self) -> None:
@@ -3973,7 +4097,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
@@ -3986,7 +4110,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only look up local users", channel.json_body["error"])
def test_success(self) -> None:
@@ -4007,7 +4131,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
{"a": 1}, channel.json_body["account_data"]["global"]["m.global"]
)
diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py
index b21f6d4689..30f12f1bff 100644
--- a/tests/rest/admin/test_username_available.py
+++ b/tests/rest/admin/test_username_available.py
@@ -11,9 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from http import HTTPStatus
-
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -40,7 +37,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
if username == "allowed":
return True
raise SynapseError(
- HTTPStatus.BAD_REQUEST,
+ 400,
"User ID already taken.",
errcode=Codes.USER_IN_USE,
)
@@ -50,27 +47,23 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
def test_username_available(self) -> None:
"""
- The endpoint should return a HTTPStatus.OK response if the username does not exist
+ The endpoint should return a 200 response if the username does not exist
"""
url = "%s?username=%s" % (self.url, "allowed")
channel = self.make_request("GET", url, access_token=self.admin_user_tok)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["available"])
def test_username_unavailable(self) -> None:
"""
- The endpoint should return a HTTPStatus.OK response if the username does not exist
+ The endpoint should return a 200 response if the username does not exist
"""
url = "%s?username=%s" % (self.url, "disallowed")
channel = self.make_request("GET", url, access_token=self.admin_user_tok)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], "M_USER_IN_USE")
self.assertEqual(channel.json_body["error"], "User ID already taken.")
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index a43a137273..c1a7fb2f8a 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import os
import re
from email.parser import Parser
+from http import HTTPStatus
from typing import Any, Dict, List, Optional, Union
from unittest.mock import Mock
@@ -95,10 +95,8 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
"""
body = {"type": "m.login.password", "user": username, "password": password}
- channel = self.make_request(
- "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
- )
- self.assertEqual(channel.code, 403, channel.result)
+ channel = self.make_request("POST", "/_matrix/client/r0/login", body)
+ self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
def test_basic_password_reset(self) -> None:
"""Test basic password reset flow"""
@@ -347,7 +345,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
shorthand=False,
)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
# Now POST to the same endpoint, mimicking the same behaviour as clicking the
# password reset confirm button
@@ -362,7 +360,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
shorthand=False,
content_is_form=True,
)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
def _get_link_from_email(self) -> str:
assert self.email_attempts, "No emails have been sent"
@@ -390,7 +388,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
new_password: str,
session_id: str,
client_secret: str,
- expected_code: int = 200,
+ expected_code: int = HTTPStatus.OK,
) -> None:
channel = self.make_request(
"POST",
@@ -479,20 +477,18 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
self.assertEqual(memberships[0].room_id, room_id, memberships)
def deactivate(self, user_id: str, tok: str) -> None:
- request_data = json.dumps(
- {
- "auth": {
- "type": "m.login.password",
- "user": user_id,
- "password": "test",
- },
- "erase": False,
- }
- )
+ request_data = {
+ "auth": {
+ "type": "m.login.password",
+ "user": user_id,
+ "password": "test",
+ },
+ "erase": False,
+ }
channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, 200, channel.json_body)
class WhoamiTestCase(unittest.HomeserverTestCase):
@@ -645,21 +641,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def test_add_email_no_at(self) -> None:
self._request_token_invalid_email(
"address-without-at.bar",
- expected_errcode=Codes.UNKNOWN,
+ expected_errcode=Codes.BAD_JSON,
expected_error="Unable to parse email address",
)
def test_add_email_two_at(self) -> None:
self._request_token_invalid_email(
"foo@foo@test.bar",
- expected_errcode=Codes.UNKNOWN,
+ expected_errcode=Codes.BAD_JSON,
expected_error="Unable to parse email address",
)
def test_add_email_bad_format(self) -> None:
self._request_token_invalid_email(
"user@bad.example.net@good.example.com",
- expected_errcode=Codes.UNKNOWN,
+ expected_errcode=Codes.BAD_JSON,
expected_error="Unable to parse email address",
)
@@ -715,7 +711,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
@@ -725,7 +723,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_delete_email(self) -> None:
@@ -747,7 +745,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
{"medium": "email", "address": self.email},
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# Get user
channel = self.make_request(
@@ -756,7 +754,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_delete_email_if_disabled(self) -> None:
@@ -781,7 +779,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
@@ -791,7 +791,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
@@ -817,7 +817,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
@@ -827,7 +829,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_no_valid_token(self) -> None:
@@ -852,7 +854,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
@@ -862,7 +866,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
@override_config({"next_link_domain_whitelist": None})
@@ -872,7 +876,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="https://example.com/a/good/site",
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
@override_config({"next_link_domain_whitelist": None})
@@ -884,7 +888,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="some-protocol://abcdefghijklmopqrstuvwxyz",
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
@override_config({"next_link_domain_whitelist": None})
@@ -895,7 +899,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="file:///host/path",
- expect_code=400,
+ expect_code=HTTPStatus.BAD_REQUEST,
)
@override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
@@ -907,28 +911,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link=None,
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
self._request_token(
"something@example.com",
"some_secret",
next_link="https://example.com/some/good/page",
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
self._request_token(
"something@example.com",
"some_secret",
next_link="https://example.org/some/also/good/page",
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
self._request_token(
"something@example.com",
"some_secret",
next_link="https://bad.example.org/some/bad/page",
- expect_code=400,
+ expect_code=HTTPStatus.BAD_REQUEST,
)
@override_config({"next_link_domain_whitelist": []})
@@ -940,7 +944,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="https://example.com/a/page",
- expect_code=400,
+ expect_code=HTTPStatus.BAD_REQUEST,
)
def _request_token(
@@ -948,8 +952,8 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
email: str,
client_secret: str,
next_link: Optional[str] = None,
- expect_code: int = 200,
- ) -> str:
+ expect_code: int = HTTPStatus.OK,
+ ) -> Optional[str]:
"""Request a validation token to add an email address to a user's account
Args:
@@ -959,7 +963,8 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
expect_code: Expected return code of the call
Returns:
- The ID of the new threepid validation session
+ The ID of the new threepid validation session, or None if the response
+ did not contain a session ID.
"""
body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
if next_link:
@@ -992,16 +997,18 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
b"account/3pid/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
self.assertEqual(expected_errcode, channel.json_body["errcode"])
- self.assertEqual(expected_error, channel.json_body["error"])
+ self.assertIn(expected_error, channel.json_body["error"])
def _validate_token(self, link: str) -> None:
# Remove the host
path = link.replace("https://example.com", "")
channel = self.make_request("GET", path, shorthand=False)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
def _get_link_from_email(self) -> str:
assert self.email_attempts, "No emails have been sent"
@@ -1051,7 +1058,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# Get user
channel = self.make_request(
@@ -1060,7 +1067,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
@@ -1091,7 +1098,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
"""Tests that not providing any MXID raises an error."""
self._test_status(
users=None,
- expected_status_code=400,
+ expected_status_code=HTTPStatus.BAD_REQUEST,
expected_errcode=Codes.MISSING_PARAM,
)
@@ -1099,7 +1106,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
"""Tests that providing an invalid MXID raises an error."""
self._test_status(
users=["bad:test"],
- expected_status_code=400,
+ expected_status_code=HTTPStatus.BAD_REQUEST,
expected_errcode=Codes.INVALID_PARAM,
)
@@ -1285,7 +1292,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
def _test_status(
self,
users: Optional[List[str]],
- expected_status_code: int = 200,
+ expected_status_code: int = HTTPStatus.OK,
expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
expected_failures: Optional[List[str]] = None,
expected_errcode: Optional[str] = None,
diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py
index aca03afd0e..7a88aa2cda 100644
--- a/tests/rest/client/test_directory.py
+++ b/tests/rest/client/test_directory.py
@@ -11,11 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
+from synapse.appservice import ApplicationService
from synapse.rest import admin
from synapse.rest.client import directory, login, room
from synapse.server import HomeServer
@@ -96,8 +96,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
# We use deliberately a localpart under the length threshold so
# that we can make sure that the check is done on the whole alias.
- data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
- request_data = json.dumps(data)
+ request_data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
@@ -109,8 +108,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
# Check with an alias of allowed length. There should already be
# a test that ensures it works in test_register.py, but let's be
# as cautious as possible here.
- data = {"room_alias_name": random_string(5)}
- request_data = json.dumps(data)
+ request_data = {"room_alias_name": random_string(5)}
channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
@@ -129,6 +127,38 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ def test_deleting_alias_via_directory_appservice(self) -> None:
+ user_id = "@as:test"
+ as_token = "i_am_an_app_service"
+
+ appservice = ApplicationService(
+ as_token,
+ id="1234",
+ namespaces={"aliases": [{"regex": "#asns-*", "exclusive": True}]},
+ sender=user_id,
+ )
+ self.hs.get_datastores().main.services_cache.append(appservice)
+
+ # Add an alias for the room, as the appservice
+ alias = RoomAlias(f"asns-{random_string(5)}", self.hs.hostname).to_string()
+ request_data = {"room_id": self.room_id}
+
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/r0/directory/room/{alias}",
+ request_data,
+ access_token=as_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ # Then try to remove the alias, as the appservice
+ channel = self.make_request(
+ "DELETE",
+ f"/_matrix/client/r0/directory/room/{alias}",
+ access_token=as_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
def test_deleting_nonexistant_alias(self) -> None:
# Check that no alias exists
alias = "#potato:test"
@@ -159,8 +189,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.hs.hostname,
)
- data = {"aliases": [self.random_alias(alias_length)]}
- request_data = json.dumps(data)
+ request_data = {"aliases": [self.random_alias(alias_length)]}
channel = self.make_request(
"PUT", url, request_data, access_token=self.user_tok
@@ -172,8 +201,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
) -> str:
alias = self.random_alias(alias_length)
url = "/_matrix/client/r0/directory/room/%s" % alias
- data = {"room_id": self.room_id}
- request_data = json.dumps(data)
+ request_data = {"room_id": self.room_id}
channel = self.make_request(
"PUT", url, request_data, access_token=self.user_tok
@@ -181,6 +209,19 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, expected_code, channel.result)
return alias
+ def test_invalid_alias(self) -> None:
+ alias = "#potato"
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/r0/directory/room/{alias}",
+ access_token=self.user_tok,
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertIn("error", channel.json_body, channel.json_body)
+ self.assertEqual(
+ channel.json_body["errcode"], "M_INVALID_PARAM", channel.json_body
+ )
+
def random_alias(self, length: int) -> str:
return RoomAlias(random_string(length), self.hs.hostname).to_string()
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index 823e8ab8c4..afc8d641be 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -43,7 +43,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.EXAMPLE_FILTER_JSON,
)
- self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"filter_id": "0"})
filter = self.get_success(
self.store.get_user_filter(user_localpart="apple", filter_id=0)
@@ -58,7 +58,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.EXAMPLE_FILTER_JSON,
)
- self.assertEqual(channel.result["code"], b"403")
+ self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
def test_add_filter_non_local_user(self) -> None:
@@ -71,7 +71,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
)
self.hs.is_mine = _is_mine
- self.assertEqual(channel.result["code"], b"403")
+ self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
def test_get_filter(self) -> None:
@@ -85,7 +85,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
)
- self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, self.EXAMPLE_FILTER)
def test_get_filter_non_existant(self) -> None:
@@ -93,7 +93,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
)
- self.assertEqual(channel.result["code"], b"404")
+ self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
# Currently invalid params do not have an appropriate errcode
@@ -103,7 +103,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
)
- self.assertEqual(channel.result["code"], b"400")
+ self.assertEqual(channel.code, 400)
# No ID also returns an invalid_id error
def test_get_filter_no_id(self) -> None:
@@ -111,4 +111,4 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
)
- self.assertEqual(channel.result["code"], b"400")
+ self.assertEqual(channel.code, 400)
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index 299b9d21e2..b0c8215744 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
@@ -26,7 +25,6 @@ from tests import unittest
class IdentityTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -34,7 +32,6 @@ class IdentityTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["enable_3pid_lookup"] = False
self.hs = self.setup_test_homeserver(config=config)
@@ -51,12 +48,12 @@ class IdentityTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
room_id = channel.json_body["room_id"]
- params = {
+ request_data = {
"id_server": "testis",
"medium": "email",
"address": "test@example.com",
+ "id_access_token": tok,
}
- request_data = json.dumps(params)
request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
channel = self.make_request(
b"POST", request_url, request_data, access_token=tok
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index f4ea1209d9..e2a4d98275 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import time
import urllib.parse
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional
from unittest.mock import Mock
from urllib.parse import urlencode
@@ -41,7 +40,7 @@ from tests.test_utils.html_parsers import TestHtmlParser
from tests.unittest import HomeserverTestCase, override_config, skip_unless
try:
- import jwt
+ from authlib.jose import jwk, jwt
HAS_JWT = True
except ImportError:
@@ -134,10 +133,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min.
@@ -152,7 +151,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
@override_config(
{
@@ -179,10 +178,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min.
@@ -197,7 +196,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
@override_config(
{
@@ -224,10 +223,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min.
@@ -242,7 +241,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
@override_config({"session_lifetime": "24h"})
def test_soft_logout(self) -> None:
@@ -250,7 +249,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we shouldn't be able to make requests without an access token
channel = self.make_request(b"GET", TEST_URL)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN")
# log in as normal
@@ -354,7 +353,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# Now try to hard logout this session
channel = self.make_request(b"POST", "/logout", access_token=access_token)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
@override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(
@@ -380,7 +379,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# Now try to hard log out all of the user's sessions
channel = self.make_request(b"POST", "/logout/all", access_token=access_token)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
def test_login_with_overly_long_device_id_fails(self) -> None:
self.register_user("mickey", "cheese")
@@ -399,7 +398,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
"/_matrix/client/v3/login",
- json.dumps(body).encode("utf8"),
+ body,
custom_headers=None,
)
@@ -841,7 +840,7 @@ class CASTestCase(unittest.HomeserverTestCase):
self.assertIn(b"SSO account deactivated", channel.result["body"])
-@skip_unless(HAS_JWT, "requires jwt")
+@skip_unless(HAS_JWT, "requires authlib")
class JWTTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -866,11 +865,9 @@ class JWTTestCase(unittest.HomeserverTestCase):
return config
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
- # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
- result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm)
- if isinstance(result, bytes):
- return result.decode("ascii")
- return result
+ header = {"alg": self.jwt_algorithm}
+ result: bytes = jwt.encode(header, payload, secret)
+ return result.decode("ascii")
def jwt_login(self, *args: Any) -> FakeChannel:
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
@@ -880,17 +877,17 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_valid_registered(self) -> None:
self.register_user("kermit", "monkey")
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
def test_login_jwt_valid_unregistered(self) -> None:
channel = self.jwt_login({"sub": "frog"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")
def test_login_jwt_invalid_signature(self) -> None:
channel = self.jwt_login({"sub": "frog"}, "notsecret")
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -899,25 +896,26 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_expired(self) -> None:
channel = self.jwt_login({"sub": "frog", "exp": 864000})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
- channel.json_body["error"], "JWT validation failed: Signature has expired"
+ channel.json_body["error"],
+ "JWT validation failed: expired_token: The token is expired",
)
def test_login_jwt_not_before(self) -> None:
now = int(time.time())
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
- "JWT validation failed: The token is not yet valid (nbf)",
+ "JWT validation failed: invalid_token: The token is not valid yet",
)
def test_login_no_sub(self) -> None:
channel = self.jwt_login({"username": "root"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
@@ -926,30 +924,31 @@ class JWTTestCase(unittest.HomeserverTestCase):
"""Test validating the issuer claim."""
# A valid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
# An invalid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
- channel.json_body["error"], "JWT validation failed: Invalid issuer"
+ channel.json_body["error"],
+ 'JWT validation failed: invalid_claim: Invalid claim "iss"',
)
# Not providing an issuer.
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
- 'JWT validation failed: Token is missing the "iss" claim',
+ 'JWT validation failed: missing_claim: Missing "iss" claim',
)
def test_login_iss_no_config(self) -> None:
"""Test providing an issuer claim without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
@override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
@@ -957,52 +956,54 @@ class JWTTestCase(unittest.HomeserverTestCase):
"""Test validating the audience claim."""
# A valid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
# An invalid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
- channel.json_body["error"], "JWT validation failed: Invalid audience"
+ channel.json_body["error"],
+ 'JWT validation failed: invalid_claim: Invalid claim "aud"',
)
# Not providing an audience.
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
- 'JWT validation failed: Token is missing the "aud" claim',
+ 'JWT validation failed: missing_claim: Missing "aud" claim',
)
def test_login_aud_no_config(self) -> None:
"""Test providing an audience without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
- channel.json_body["error"], "JWT validation failed: Invalid audience"
+ channel.json_body["error"],
+ 'JWT validation failed: invalid_claim: Invalid claim "aud"',
)
def test_login_default_sub(self) -> None:
"""Test reading user ID from the default subject claim."""
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
@override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
def test_login_custom_sub(self) -> None:
"""Test reading user ID from a custom subject claim."""
channel = self.jwt_login({"username": "frog"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")
def test_login_no_token(self) -> None:
params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
@@ -1010,7 +1011,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
# RSS256, with a public key configured in synapse as "jwt_secret", and tokens
# signed by the private key.
-@skip_unless(HAS_JWT, "requires jwt")
+@skip_unless(HAS_JWT, "requires authlib")
class JWTPubKeyTestCase(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
@@ -1071,11 +1072,11 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
return config
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
- # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
- result: Union[bytes, str] = jwt.encode(payload, secret, "RS256")
- if isinstance(result, bytes):
- return result.decode("ascii")
- return result
+ header = {"alg": "RS256"}
+ if secret.startswith("-----BEGIN RSA PRIVATE KEY-----"):
+ secret = jwk.dumps(secret, kty="RSA")
+ result: bytes = jwt.encode(header, payload, secret)
+ return result.decode("ascii")
def jwt_login(self, *args: Any) -> FakeChannel:
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
@@ -1084,12 +1085,12 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
def test_login_jwt_valid(self) -> None:
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
def test_login_jwt_invalid_signature(self) -> None:
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -1150,7 +1151,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.service.token
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
def test_login_appservice_user_bot(self) -> None:
"""Test that the appservice bot can use /login"""
@@ -1164,7 +1165,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.service.token
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
def test_login_appservice_wrong_user(self) -> None:
"""Test that non-as users cannot login with the as token"""
@@ -1178,7 +1179,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.service.token
)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
def test_login_appservice_wrong_as(self) -> None:
"""Test that as users cannot login with wrong as token"""
@@ -1192,7 +1193,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.another_service.token
)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
def test_login_appservice_no_token(self) -> None:
"""Test that users must provide a token when using the appservice
@@ -1206,7 +1207,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
@skip_unless(HAS_OIDC, "requires OIDC")
diff --git a/tests/rest/client/test_models.py b/tests/rest/client/test_models.py
new file mode 100644
index 0000000000..a9da00665e
--- /dev/null
+++ b/tests/rest/client/test_models.py
@@ -0,0 +1,53 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+from pydantic import ValidationError
+
+from synapse.rest.client.models import EmailRequestTokenBody
+
+
+class EmailRequestTokenBodyTestCase(unittest.TestCase):
+ base_request = {
+ "client_secret": "hunter2",
+ "email": "alice@wonderland.com",
+ "send_attempt": 1,
+ }
+
+ def test_token_required_if_id_server_provided(self) -> None:
+ with self.assertRaises(ValidationError):
+ EmailRequestTokenBody.parse_obj(
+ {
+ **self.base_request,
+ "id_server": "identity.wonderland.com",
+ }
+ )
+ with self.assertRaises(ValidationError):
+ EmailRequestTokenBody.parse_obj(
+ {
+ **self.base_request,
+ "id_server": "identity.wonderland.com",
+ "id_access_token": None,
+ }
+ )
+
+ def test_token_typechecked_when_id_server_provided(self) -> None:
+ with self.assertRaises(ValidationError):
+ EmailRequestTokenBody.parse_obj(
+ {
+ **self.base_request,
+ "id_server": "identity.wonderland.com",
+ "id_access_token": 1337,
+ }
+ )
diff --git a/tests/rest/client/test_password_policy.py b/tests/rest/client/test_password_policy.py
index 3a74d2e96c..e19d21d6ee 100644
--- a/tests/rest/client/test_password_policy.py
+++ b/tests/rest/client/test_password_policy.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
@@ -89,7 +88,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_too_short(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": "shorty"})
+ request_data = {"username": "kermit", "password": "shorty"}
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -100,7 +99,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_no_digit(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
+ request_data = {"username": "kermit", "password": "longerpassword"}
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -111,7 +110,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_no_symbol(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
+ request_data = {"username": "kermit", "password": "l0ngerpassword"}
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -122,7 +121,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_no_uppercase(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
+ request_data = {"username": "kermit", "password": "l0ngerpassword!"}
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -133,7 +132,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_no_lowercase(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
+ request_data = {"username": "kermit", "password": "L0NGERPASSWORD!"}
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -144,7 +143,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_compliant(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
+ request_data = {"username": "kermit", "password": "L0ngerpassword!"}
channel = self.make_request("POST", self.register_url, request_data)
# Getting a 401 here means the password has passed validation and the server has
@@ -161,16 +160,14 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
user_id = self.register_user("kermit", compliant_password)
tok = self.login("kermit", compliant_password)
- request_data = json.dumps(
- {
- "new_password": not_compliant_password,
- "auth": {
- "password": compliant_password,
- "type": LoginType.PASSWORD,
- "user": user_id,
- },
- }
- )
+ request_data = {
+ "new_password": not_compliant_password,
+ "auth": {
+ "password": compliant_password,
+ "type": LoginType.PASSWORD,
+ "user": user_id,
+ },
+ }
channel = self.make_request(
"POST",
"/_matrix/client/r0/account/password",
diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index 77c3ced42e..8de5a342ae 100644
--- a/tests/rest/client/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -13,6 +13,8 @@
# limitations under the License.
"""Tests REST events for /profile paths."""
+import urllib.parse
+from http import HTTPStatus
from typing import Any, Dict, Optional
from twisted.test.proto_helpers import MemoryReactor
@@ -49,6 +51,12 @@ class ProfileTestCase(unittest.HomeserverTestCase):
res = self._get_displayname()
self.assertEqual(res, "owner")
+ def test_get_displayname_rejects_bad_username(self) -> None:
+ channel = self.make_request(
+ "GET", f"/profile/{urllib.parse.quote('@alice:')}/displayname"
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+
def test_set_displayname(self) -> None:
channel = self.make_request(
"PUT",
@@ -145,18 +153,22 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 400, channel.result)
- def _get_displayname(self, name: Optional[str] = None) -> str:
+ def _get_displayname(self, name: Optional[str] = None) -> Optional[str]:
channel = self.make_request(
"GET", "/profile/%s/displayname" % (name or self.owner,)
)
self.assertEqual(channel.code, 200, channel.result)
- return channel.json_body["displayname"]
+ # FIXME: If a user has no displayname set, Synapse returns 200 and omits a
+ # displayname from the response. This contradicts the spec, see #13137.
+ return channel.json_body.get("displayname")
- def _get_avatar_url(self, name: Optional[str] = None) -> str:
+ def _get_avatar_url(self, name: Optional[str] = None) -> Optional[str]:
channel = self.make_request(
"GET", "/profile/%s/avatar_url" % (name or self.owner,)
)
self.assertEqual(channel.code, 200, channel.result)
+ # FIXME: If a user has no avatar set, Synapse returns 200 and omits an
+ # avatar_url from the response. This contradicts the spec, see #13137.
return channel.json_body.get("avatar_url")
@unittest.override_config({"max_avatar_size": 50})
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index 7401b5e0c0..be4c67d68e 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -76,12 +76,12 @@ class RedactionsTestCase(HomeserverTestCase):
path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id)
channel = self.make_request("POST", path, content={}, access_token=access_token)
- self.assertEqual(int(channel.result["code"]), expect_code)
+ self.assertEqual(channel.code, expect_code)
return channel.json_body
def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]:
channel = self.make_request("GET", "sync", access_token=self.mod_access_token)
- self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.code, 200)
room_sync = channel.json_body["rooms"]["join"][room_id]
return room_sync["timeline"]["events"]
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index afb08b2736..b781875d52 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
-import json
import os
from typing import Any, Dict, List, Tuple
@@ -62,15 +61,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
self.hs.get_datastores().main.services_cache.append(appservice)
- request_data = json.dumps(
- {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
- )
+ request_data = {
+ "username": "as_user_kermit",
+ "type": APP_SERVICE_REGISTRATION_TYPE,
+ }
channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
det_data = {"user_id": user_id, "home_server": self.hs.hostname}
self.assertDictContainsSubset(det_data, channel.json_body)
@@ -85,49 +85,46 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
self.hs.get_datastores().main.services_cache.append(appservice)
- request_data = json.dumps({"username": "as_user_kermit"})
+ request_data = {"username": "as_user_kermit"}
channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- self.assertEqual(channel.result["code"], b"400", channel.result)
+ self.assertEqual(channel.code, 400, msg=channel.result)
def test_POST_appservice_registration_invalid(self) -> None:
self.appservice = None # no application service exists
- request_data = json.dumps(
- {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
- )
+ request_data = {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
def test_POST_bad_password(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": 666})
+ request_data = {"username": "kermit", "password": 666}
channel = self.make_request(b"POST", self.url, request_data)
- self.assertEqual(channel.result["code"], b"400", channel.result)
+ self.assertEqual(channel.code, 400, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Invalid password")
def test_POST_bad_username(self) -> None:
- request_data = json.dumps({"username": 777, "password": "monkey"})
+ request_data = {"username": 777, "password": "monkey"}
channel = self.make_request(b"POST", self.url, request_data)
- self.assertEqual(channel.result["code"], b"400", channel.result)
+ self.assertEqual(channel.code, 400, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Invalid username")
def test_POST_user_valid(self) -> None:
user_id = "@kermit:test"
device_id = "frogfone"
- params = {
+ request_data = {
"username": "kermit",
"password": "monkey",
"device_id": device_id,
"auth": {"type": LoginType.DUMMY},
}
- request_data = json.dumps(params)
channel = self.make_request(b"POST", self.url, request_data)
det_data = {
@@ -135,17 +132,17 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"home_server": self.hs.hostname,
"device_id": device_id,
}
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
@override_config({"enable_registration": False})
def test_POST_disabled_registration(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": "monkey"})
+ request_data = {"username": "kermit", "password": "monkey"}
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
channel = self.make_request(b"POST", self.url, request_data)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Registration has been disabled")
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -156,7 +153,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
def test_POST_disabled_guest_registration(self) -> None:
@@ -164,7 +161,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Guest access is disabled")
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
@@ -174,40 +171,39 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", url, b"{}")
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting(self) -> None:
for i in range(0, 6):
- params = {
+ request_data = {
"username": "kermit" + str(i),
"password": "monkey",
"device_id": "frogfone",
"auth": {"type": LoginType.DUMMY},
}
- request_data = json.dumps(params)
channel = self.make_request(b"POST", self.url, request_data)
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
@override_config({"registration_requires_token": True})
def test_POST_registration_requires_token(self) -> None:
@@ -234,8 +230,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
}
# Request without auth to get flows and session
- channel = self.make_request(b"POST", self.url, json.dumps(params))
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ channel = self.make_request(b"POST", self.url, params)
+ self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"]
# Synapse adds a dummy stage to differentiate flows where otherwise one
# flow would be a subset of another flow.
@@ -251,9 +247,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session,
}
- request_data = json.dumps(params)
- channel = self.make_request(b"POST", self.url, request_data)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ channel = self.make_request(b"POST", self.url, params)
+ self.assertEqual(channel.code, 401, msg=channel.result)
completed = channel.json_body["completed"]
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
@@ -262,14 +257,13 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"type": LoginType.DUMMY,
"session": session,
}
- request_data = json.dumps(params)
- channel = self.make_request(b"POST", self.url, request_data)
+ channel = self.make_request(b"POST", self.url, params)
det_data = {
"user_id": f"@{username}:{self.hs.hostname}",
"home_server": self.hs.hostname,
"device_id": device_id,
}
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
# Check the `completed` counter has been incremented and pending is 0
@@ -290,7 +284,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"password": "monkey",
}
# Request without auth to get session
- channel = self.make_request(b"POST", self.url, json.dumps(params))
+ channel = self.make_request(b"POST", self.url, params)
session = channel.json_body["session"]
# Test with token param missing (invalid)
@@ -298,22 +292,22 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"type": LoginType.REGISTRATION_TOKEN,
"session": session,
}
- channel = self.make_request(b"POST", self.url, json.dumps(params))
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ channel = self.make_request(b"POST", self.url, params)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
self.assertEqual(channel.json_body["completed"], [])
# Test with non-string (invalid)
params["auth"]["token"] = 1234
- channel = self.make_request(b"POST", self.url, json.dumps(params))
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ channel = self.make_request(b"POST", self.url, params)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
self.assertEqual(channel.json_body["completed"], [])
# Test with unknown token (invalid)
params["auth"]["token"] = "1234"
- channel = self.make_request(b"POST", self.url, json.dumps(params))
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ channel = self.make_request(b"POST", self.url, params)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
@@ -337,9 +331,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
params1: JsonDict = {"username": "bert", "password": "monkey"}
params2: JsonDict = {"username": "ernie", "password": "monkey"}
# Do 2 requests without auth to get two session IDs
- channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+ channel1 = self.make_request(b"POST", self.url, params1)
session1 = channel1.json_body["session"]
- channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+ channel2 = self.make_request(b"POST", self.url, params2)
session2 = channel2.json_body["session"]
# Use token with session1 and check `pending` is 1
@@ -348,9 +342,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session1,
}
- self.make_request(b"POST", self.url, json.dumps(params1))
+ self.make_request(b"POST", self.url, params1)
# Repeat request to make sure pending isn't increased again
- self.make_request(b"POST", self.url, json.dumps(params1))
+ self.make_request(b"POST", self.url, params1)
pending = self.get_success(
store.db_pool.simple_select_one_onecol(
"registration_tokens",
@@ -366,14 +360,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session2,
}
- channel = self.make_request(b"POST", self.url, json.dumps(params2))
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ channel = self.make_request(b"POST", self.url, params2)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
# Complete registration with session1
params1["auth"]["type"] = LoginType.DUMMY
- self.make_request(b"POST", self.url, json.dumps(params1))
+ self.make_request(b"POST", self.url, params1)
# Check pending=0 and completed=1
res = self.get_success(
store.db_pool.simple_select_one(
@@ -386,8 +380,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(res["completed"], 1)
# Check auth still fails when using token with session2
- channel = self.make_request(b"POST", self.url, json.dumps(params2))
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ channel = self.make_request(b"POST", self.url, params2)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
@@ -411,7 +405,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
params: JsonDict = {"username": "kermit", "password": "monkey"}
# Request without auth to get session
- channel = self.make_request(b"POST", self.url, json.dumps(params))
+ channel = self.make_request(b"POST", self.url, params)
session = channel.json_body["session"]
# Check authentication fails with expired token
@@ -420,8 +414,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session,
}
- channel = self.make_request(b"POST", self.url, json.dumps(params))
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ channel = self.make_request(b"POST", self.url, params)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
@@ -435,7 +429,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
# Check authentication succeeds
- channel = self.make_request(b"POST", self.url, json.dumps(params))
+ channel = self.make_request(b"POST", self.url, params)
completed = channel.json_body["completed"]
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
@@ -460,9 +454,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
# Do 2 requests without auth to get two session IDs
params1: JsonDict = {"username": "bert", "password": "monkey"}
params2: JsonDict = {"username": "ernie", "password": "monkey"}
- channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+ channel1 = self.make_request(b"POST", self.url, params1)
session1 = channel1.json_body["session"]
- channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+ channel2 = self.make_request(b"POST", self.url, params2)
session2 = channel2.json_body["session"]
# Use token with both sessions
@@ -471,18 +465,18 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session1,
}
- self.make_request(b"POST", self.url, json.dumps(params1))
+ self.make_request(b"POST", self.url, params1)
params2["auth"] = {
"type": LoginType.REGISTRATION_TOKEN,
"token": token,
"session": session2,
}
- self.make_request(b"POST", self.url, json.dumps(params2))
+ self.make_request(b"POST", self.url, params2)
# Complete registration with session1
params1["auth"]["type"] = LoginType.DUMMY
- self.make_request(b"POST", self.url, json.dumps(params1))
+ self.make_request(b"POST", self.url, params1)
# Check `result` of registration token stage for session1 is `True`
result1 = self.get_success(
@@ -550,7 +544,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
# Do request without auth to get a session ID
params: JsonDict = {"username": "kermit", "password": "monkey"}
- channel = self.make_request(b"POST", self.url, json.dumps(params))
+ channel = self.make_request(b"POST", self.url, params)
session = channel.json_body["session"]
# Use token
@@ -559,7 +553,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session,
}
- self.make_request(b"POST", self.url, json.dumps(params))
+ self.make_request(b"POST", self.url, params)
# Delete token
self.get_success(
@@ -576,7 +570,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_advertised_flows(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}")
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"]
# with the stock config, we only expect the dummy flow
@@ -592,14 +586,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"require_at_registration": True,
},
"account_threepid_delegates": {
- "email": "https://id_server",
"msisdn": "https://id_server",
},
+ "email": {"notif_from": "Synapse <synapse@example.com>"},
}
)
def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}")
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"]
self.assertCountEqual(
@@ -631,7 +625,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
def test_advertised_flows_no_msisdn_email_required(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}")
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"]
# with the stock config, we expect all four combinations of 3pid
@@ -803,13 +797,13 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
# endpoint.
channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
)
@@ -827,15 +821,14 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
admin_tok = self.login("admin", "adminpassword")
url = "/_synapse/admin/v1/account_validity/validity"
- params = {"user_id": user_id}
- request_data = json.dumps(params)
+ request_data = {"user_id": user_id}
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
def test_manual_expire(self) -> None:
user_id = self.register_user("kermit", "monkey")
@@ -845,19 +838,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
admin_tok = self.login("admin", "adminpassword")
url = "/_synapse/admin/v1/account_validity/validity"
- params = {
+ request_data = {
"user_id": user_id,
"expiration_ts": 0,
"enable_renewal_emails": False,
}
- request_data = json.dumps(params)
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
)
@@ -870,25 +862,24 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
admin_tok = self.login("admin", "adminpassword")
url = "/_synapse/admin/v1/account_validity/validity"
- params = {
+ request_data = {
"user_id": user_id,
"expiration_ts": 0,
"enable_renewal_emails": False,
}
- request_data = json.dumps(params)
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Try to log the user out
channel = self.make_request(b"POST", "/logout", access_token=tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Log the user in again (allowed for expired accounts)
tok = self.login("kermit", "monkey")
# Try to log out all of the user's sessions
channel = self.make_request(b"POST", "/logout/all", access_token=tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
@@ -963,7 +954,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
channel = self.make_request(b"GET", url)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Check that we're getting HTML back.
content_type = channel.headers.getRawHeaders(b"Content-Type")
@@ -981,7 +972,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# Move 1 day forward. Try to renew with the same token again.
url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
channel = self.make_request(b"GET", url)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Check that we're getting HTML back.
content_type = channel.headers.getRawHeaders(b"Content-Type")
@@ -1001,14 +992,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# succeed.
self.reactor.advance(datetime.timedelta(days=3).total_seconds())
channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
def test_renewal_invalid_token(self) -> None:
# Hit the renewal endpoint with an invalid token and check that it behaves as
# expected, i.e. that it responds with 404 Not Found and the correct HTML.
url = "/_matrix/client/unstable/account_validity/renew?token=123"
channel = self.make_request(b"GET", url)
- self.assertEqual(channel.result["code"], b"404", channel.result)
+ self.assertEqual(channel.code, 404, msg=channel.result)
# Check that we're getting HTML back.
content_type = channel.headers.getRawHeaders(b"Content-Type")
@@ -1032,7 +1023,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"/_matrix/client/unstable/account_validity/send_mail",
access_token=tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(len(self.email_attempts), 1)
@@ -1041,16 +1032,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
(user_id, tok) = self.create_user()
- request_data = json.dumps(
- {
- "auth": {
- "type": "m.login.password",
- "user": user_id,
- "password": "monkey",
- },
- "erase": False,
- }
- )
+ request_data = {
+ "auth": {
+ "type": "m.login.password",
+ "user": user_id,
+ "password": "monkey",
+ },
+ "erase": False,
+ }
channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
@@ -1107,7 +1096,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"/_matrix/client/unstable/account_validity/send_mail",
access_token=tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(len(self.email_attempts), 1)
@@ -1187,7 +1176,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
b"GET",
f"{self.url}?token={token}",
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["valid"], True)
def test_GET_token_invalid(self) -> None:
@@ -1196,7 +1185,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
b"GET",
f"{self.url}?token={token}",
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["valid"], False)
@override_config(
@@ -1212,10 +1201,10 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
)
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
@@ -1223,4 +1212,4 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
b"GET",
f"{self.url}?token={token}",
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 62e4db23ef..651f4f415d 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -728,6 +728,7 @@ class RelationsTestCase(BaseRelationsTestCase):
class RelationPaginationTestCase(BaseRelationsTestCase):
+ @unittest.override_config({"experimental_features": {"msc3715_enabled": True}})
def test_basic_paginate_relations(self) -> None:
"""Tests that calling pagination API correctly the latest relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
@@ -799,7 +800,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
)
expected_event_ids.append(channel.json_body["event_id"])
- prev_token = ""
+ prev_token: Optional[str] = ""
found_event_ids: List[str] = []
for _ in range(20):
from_token = ""
@@ -998,7 +999,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations,
)
- self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 6)
+ self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
def test_annotation_to_annotation(self) -> None:
"""Any relation to an annotation should be ignored."""
@@ -1034,7 +1035,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations,
)
- self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6)
+ self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
def test_thread(self) -> None:
"""
@@ -1059,6 +1060,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
participated, bundled_aggregations.get("current_user_participated")
)
# The latest thread event has some fields that don't matter.
+ self.assertIn("latest_event", bundled_aggregations)
self.assert_dict(
{
"content": {
@@ -1071,28 +1073,28 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
"sender": self.user2_id,
"type": "m.room.test",
},
- bundled_aggregations.get("latest_event"),
+ bundled_aggregations["latest_event"],
)
return assert_thread
# The "user" sent the root event and is making queries for the bundled
# aggregations: they have participated.
- self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
+ self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 9)
# The "user2" sent replies in the thread and is making queries for the
# bundled aggregations: they have participated.
#
# Note that this re-uses some cached values, so the total number of
# queries is much smaller.
self._test_bundled_aggregations(
- RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token
+ RelationTypes.THREAD, _gen_assert(True), 3, access_token=self.user2_token
)
# A user with no interactions with the thread: they have not participated.
user3_id, user3_token = self._create_user("charlie")
self.helper.join(self.room, user=user3_id, tok=user3_token)
self._test_bundled_aggregations(
- RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token
+ RelationTypes.THREAD, _gen_assert(False), 3, access_token=user3_token
)
def test_thread_with_bundled_aggregations_for_latest(self) -> None:
@@ -1111,6 +1113,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
self.assertEqual(2, bundled_aggregations.get("count"))
self.assertTrue(bundled_aggregations.get("current_user_participated"))
# The latest thread event has some fields that don't matter.
+ self.assertIn("latest_event", bundled_aggregations)
self.assert_dict(
{
"content": {
@@ -1123,7 +1126,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
"sender": self.user_id,
"type": "m.room.test",
},
- bundled_aggregations.get("latest_event"),
+ bundled_aggregations["latest_event"],
)
# Check the unsigned field on the latest event.
self.assert_dict(
@@ -1139,7 +1142,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations["latest_event"].get("unsigned"),
)
- self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
+ self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
def test_nested_thread(self) -> None:
"""
diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py
index 20a259fc43..7cb1017a4a 100644
--- a/tests/rest/client/test_report_event.py
+++ b/tests/rest/client/test_report_event.py
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
-
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -77,11 +75,6 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
def _assert_status(self, response_status: int, data: JsonDict) -> None:
channel = self.make_request(
- "POST",
- self.report_path,
- json.dumps(data),
- access_token=self.other_user_tok,
- )
- self.assertEqual(
- response_status, int(channel.result["code"]), msg=channel.result["body"]
+ "POST", self.report_path, data, access_token=self.other_user_tok
)
+ self.assertEqual(response_status, channel.code, msg=channel.result["body"])
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index ac9c113354..9c8c1889d3 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from synapse.visibility import filter_events_for_client
@@ -188,7 +188,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
message_handler = self.hs.get_message_handler()
create_event = self.get_success(
message_handler.get_room_data(
- self.user_id, room_id, EventTypes.Create, state_key=""
+ create_requester(self.user_id), room_id, EventTypes.Create, state_key=""
)
)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index f523d89b8f..c7eb88d33f 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -18,10 +18,14 @@
"""Tests REST events for /rooms paths."""
import json
-from typing import Any, Dict, Iterable, List, Optional
+from http import HTTPStatus
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from unittest.mock import Mock, call
from urllib import parse as urlparse
+from parameterized import param, parameterized
+from typing_extensions import Literal
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -30,7 +34,9 @@ from synapse.api.constants import (
EventContentFields,
EventTypes,
Membership,
+ PublicRoomsFilterFields,
RelationTypes,
+ RoomTypes,
)
from synapse.api.errors import Codes, HttpResponseException
from synapse.handlers.pagination import PurgeStatus
@@ -42,6 +48,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
+from tests.http.server._base import make_request_with_cancellation_test
from tests.test_utils import make_awaitable
PATH_PREFIX = b"/_matrix/client/api/v1"
@@ -98,7 +105,7 @@ class RoomPermissionsTestCase(RoomBase):
channel = self.make_request(
"PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}'
)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
# set topic for public room
channel = self.make_request(
@@ -106,7 +113,7 @@ class RoomPermissionsTestCase(RoomBase):
("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"),
b'{"topic":"Public Room Topic"}',
)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
# auth as user_id now
self.helper.auth_user_id = self.user_id
@@ -128,28 +135,28 @@ class RoomPermissionsTestCase(RoomBase):
"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
msg_content,
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# send message in created room not joined (no state), expect 403
channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# send message in created room and invited, expect 403
self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# send message in created room and joined, expect 200
self.helper.join(room=self.created_rmid, user=self.user_id)
channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# send message in created room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id)
channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_topic_perms(self) -> None:
topic_content = b'{"topic":"My Topic Name"}'
@@ -159,28 +166,28 @@ class RoomPermissionsTestCase(RoomBase):
channel = self.make_request(
"PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
channel = self.make_request(
"GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room not joined, expect 403
channel = self.make_request("PUT", topic_path, topic_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", topic_path)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# set topic in created PRIVATE room and invited, expect 403
self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
channel = self.make_request("PUT", topic_path, topic_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# get topic in created PRIVATE room and invited, expect 403
channel = self.make_request("GET", topic_path)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room and joined, expect 200
self.helper.join(room=self.created_rmid, user=self.user_id)
@@ -188,25 +195,25 @@ class RoomPermissionsTestCase(RoomBase):
# Only room ops can set topic by default
self.helper.auth_user_id = self.rmcreator_id
channel = self.make_request("PUT", topic_path, topic_content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.helper.auth_user_id = self.user_id
channel = self.make_request("GET", topic_path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body)
# set/get topic in created PRIVATE room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id)
channel = self.make_request("PUT", topic_path, topic_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", topic_path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# get topic in PUBLIC room, not joined, expect 403
channel = self.make_request(
"GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# set topic in PUBLIC room, not joined, expect 403
channel = self.make_request(
@@ -214,7 +221,7 @@ class RoomPermissionsTestCase(RoomBase):
"/rooms/%s/state/m.room.topic" % self.created_public_rmid,
topic_content,
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def _test_get_membership(
self, room: str, members: Iterable = frozenset(), expect_code: int = 200
@@ -303,14 +310,14 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=self.rmcreator_id,
membership=Membership.JOIN,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
self.helper.change_membership(
room=room,
src=self.user_id,
targ=self.rmcreator_id,
membership=Membership.LEAVE,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
def test_joined_permissions(self) -> None:
@@ -336,7 +343,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=other,
membership=Membership.JOIN,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
# set left of other, expect 403
@@ -345,7 +352,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=other,
membership=Membership.LEAVE,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
# set left of self, expect 200
@@ -365,7 +372,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=usr,
membership=Membership.INVITE,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
self.helper.change_membership(
@@ -373,7 +380,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=usr,
membership=Membership.JOIN,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
# It is always valid to LEAVE if you've already left (currently.)
@@ -382,7 +389,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=self.rmcreator_id,
membership=Membership.LEAVE,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
# tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember
@@ -399,7 +406,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=other,
membership=Membership.BAN,
- expect_code=403, # expect failure
+ expect_code=HTTPStatus.FORBIDDEN, # expect failure
expect_errcode=Codes.FORBIDDEN,
)
@@ -409,7 +416,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.rmcreator_id,
targ=other,
membership=Membership.BAN,
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
# from ban to invite: Must never happen.
@@ -418,7 +425,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.rmcreator_id,
targ=other,
membership=Membership.INVITE,
- expect_code=403, # expect failure
+ expect_code=HTTPStatus.FORBIDDEN, # expect failure
expect_errcode=Codes.BAD_STATE,
)
@@ -428,7 +435,7 @@ class RoomPermissionsTestCase(RoomBase):
src=other,
targ=other,
membership=Membership.JOIN,
- expect_code=403, # expect failure
+ expect_code=HTTPStatus.FORBIDDEN, # expect failure
expect_errcode=Codes.BAD_STATE,
)
@@ -438,7 +445,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.rmcreator_id,
targ=other,
membership=Membership.BAN,
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
# from ban to knock: Must never happen.
@@ -447,7 +454,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.rmcreator_id,
targ=other,
membership=Membership.KNOCK,
- expect_code=403, # expect failure
+ expect_code=HTTPStatus.FORBIDDEN, # expect failure
expect_errcode=Codes.BAD_STATE,
)
@@ -457,7 +464,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=other,
membership=Membership.LEAVE,
- expect_code=403, # expect failure
+ expect_code=HTTPStatus.FORBIDDEN, # expect failure
expect_errcode=Codes.FORBIDDEN,
)
@@ -467,10 +474,53 @@ class RoomPermissionsTestCase(RoomBase):
src=self.rmcreator_id,
targ=other,
membership=Membership.LEAVE,
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
+class RoomStateTestCase(RoomBase):
+ """Tests /rooms/$room_id/state."""
+
+ user_id = "@sid1:red"
+
+ def test_get_state_cancellation(self) -> None:
+ """Test cancellation of a `/rooms/$room_id/state` request."""
+ room_id = self.helper.create_room_as(self.user_id)
+ channel = make_request_with_cancellation_test(
+ "test_state_cancellation",
+ self.reactor,
+ self.site,
+ "GET",
+ "/rooms/%s/state" % room_id,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
+ self.assertCountEqual(
+ [state_event["type"] for state_event in channel.json_list],
+ {
+ "m.room.create",
+ "m.room.power_levels",
+ "m.room.join_rules",
+ "m.room.member",
+ "m.room.history_visibility",
+ },
+ )
+
+ def test_get_state_event_cancellation(self) -> None:
+ """Test cancellation of a `/rooms/$room_id/state/$event_type` request."""
+ room_id = self.helper.create_room_as(self.user_id)
+ channel = make_request_with_cancellation_test(
+ "test_state_cancellation",
+ self.reactor,
+ self.site,
+ "GET",
+ "/rooms/%s/state/m.room.member/%s" % (room_id, self.user_id),
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
+ self.assertEqual(channel.json_body, {"membership": "join"})
+
+
class RoomsMemberListTestCase(RoomBase):
"""Tests /rooms/$room_id/members/list REST events."""
@@ -481,16 +531,16 @@ class RoomsMemberListTestCase(RoomBase):
def test_get_member_list(self) -> None:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
def test_get_member_list_no_room(self) -> None:
channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission(self) -> None:
room_id = self.helper.create_room_as("@some_other_guy:red")
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission_with_at_token(self) -> None:
"""
@@ -501,7 +551,7 @@ class RoomsMemberListTestCase(RoomBase):
# first sync to get an at token
channel = self.make_request("GET", "/sync")
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
sync_token = channel.json_body["next_batch"]
# check that permission is denied for @sid1:red to get the
@@ -510,7 +560,7 @@ class RoomsMemberListTestCase(RoomBase):
"GET",
f"/rooms/{room_id}/members?at={sync_token}",
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission_former_member(self) -> None:
"""
@@ -523,14 +573,14 @@ class RoomsMemberListTestCase(RoomBase):
# check that the user can see the member list to start with
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# ban the user
self.helper.change_membership(room_id, "@alice:red", self.user_id, "ban")
# check the user can no longer see the member list
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission_former_member_with_at_token(self) -> None:
"""
@@ -544,14 +594,14 @@ class RoomsMemberListTestCase(RoomBase):
# sync to get an at token
channel = self.make_request("GET", "/sync")
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
sync_token = channel.json_body["next_batch"]
# check that the user can see the member list to start with
channel = self.make_request(
"GET", "/rooms/%s/members?at=%s" % (room_id, sync_token)
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# ban the user (Note: the user is actually allowed to see this event and
# state so that they know they're banned!)
@@ -563,14 +613,14 @@ class RoomsMemberListTestCase(RoomBase):
# now, with the original user, sync again to get a new at token
channel = self.make_request("GET", "/sync")
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
sync_token = channel.json_body["next_batch"]
# check the user can no longer see the updated member list
channel = self.make_request(
"GET", "/rooms/%s/members?at=%s" % (room_id, sync_token)
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_get_member_list_mixed_memberships(self) -> None:
room_creator = "@some_other_guy:red"
@@ -579,17 +629,73 @@ class RoomsMemberListTestCase(RoomBase):
self.helper.invite(room=room_id, src=room_creator, targ=self.user_id)
# can't see list if you're just invited.
channel = self.make_request("GET", room_path)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
self.helper.join(room=room_id, user=self.user_id)
# can see list now joined
channel = self.make_request("GET", room_path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.helper.leave(room=room_id, user=self.user_id)
# can see old list once left
channel = self.make_request("GET", room_path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
+
+ def test_get_member_list_cancellation(self) -> None:
+ """Test cancellation of a `/rooms/$room_id/members` request."""
+ room_id = self.helper.create_room_as(self.user_id)
+ channel = make_request_with_cancellation_test(
+ "test_get_member_list_cancellation",
+ self.reactor,
+ self.site,
+ "GET",
+ "/rooms/%s/members" % room_id,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+ self.assertLessEqual(
+ {
+ "content": {"membership": "join"},
+ "room_id": room_id,
+ "sender": self.user_id,
+ "state_key": self.user_id,
+ "type": "m.room.member",
+ "user_id": self.user_id,
+ }.items(),
+ channel.json_body["chunk"][0].items(),
+ )
+
+ def test_get_member_list_with_at_token_cancellation(self) -> None:
+ """Test cancellation of a `/rooms/$room_id/members?at=<sync token>` request."""
+ room_id = self.helper.create_room_as(self.user_id)
+
+ # first sync to get an at token
+ channel = self.make_request("GET", "/sync")
+ self.assertEqual(HTTPStatus.OK, channel.code)
+ sync_token = channel.json_body["next_batch"]
+
+ channel = make_request_with_cancellation_test(
+ "test_get_member_list_with_at_token_cancellation",
+ self.reactor,
+ self.site,
+ "GET",
+ "/rooms/%s/members?at=%s" % (room_id, sync_token),
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+ self.assertLessEqual(
+ {
+ "content": {"membership": "join"},
+ "room_id": room_id,
+ "sender": self.user_id,
+ "state_key": self.user_id,
+ "type": "m.room.member",
+ "user_id": self.user_id,
+ }.items(),
+ channel.json_body["chunk"][0].items(),
+ )
class RoomsCreateTestCase(RoomBase):
@@ -601,19 +707,34 @@ class RoomsCreateTestCase(RoomBase):
# POST with no config keys, expect new room id
channel = self.make_request("POST", "/createRoom", "{}")
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
+ assert channel.resource_usage is not None
+ self.assertEqual(44, channel.resource_usage.db_txn_count)
+
+ def test_post_room_initial_state(self) -> None:
+ # POST with initial_state config key, expect new room id
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ b'{"initial_state":[{"type": "m.bridge", "content": {}}]}',
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ self.assertTrue("room_id" in channel.json_body)
+ assert channel.resource_usage is not None
+ self.assertEqual(50, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id
channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}')
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_custom_key(self) -> None:
# POST with custom config keys, expect new room id
channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}')
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_known_and_unknown_keys(self) -> None:
@@ -621,16 +742,16 @@ class RoomsCreateTestCase(RoomBase):
channel = self.make_request(
"POST", "/createRoom", b'{"visibility":"private","custom":"things"}'
)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_invalid_content(self) -> None:
# POST with invalid content / paths, expect 400
channel = self.make_request("POST", "/createRoom", b'{"visibili')
- self.assertEqual(400, channel.code)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code)
channel = self.make_request("POST", "/createRoom", b'["hello"]')
- self.assertEqual(400, channel.code)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code)
def test_post_room_invitees_invalid_mxid(self) -> None:
# POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088
@@ -638,7 +759,7 @@ class RoomsCreateTestCase(RoomBase):
channel = self.make_request(
"POST", "/createRoom", b'{"invite":["@alice:example.com "]}'
)
- self.assertEqual(400, channel.code)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code)
@unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}})
def test_post_room_invitees_ratelimit(self) -> None:
@@ -649,20 +770,18 @@ class RoomsCreateTestCase(RoomBase):
# Build the request's content. We use local MXIDs because invites over federation
# are more difficult to mock.
- content = json.dumps(
- {
- "invite": [
- "@alice1:red",
- "@alice2:red",
- "@alice3:red",
- "@alice4:red",
- ]
- }
- ).encode("utf8")
+ content = {
+ "invite": [
+ "@alice1:red",
+ "@alice2:red",
+ "@alice3:red",
+ "@alice4:red",
+ ]
+ }
# Test that the invites are correctly ratelimited.
channel = self.make_request("POST", "/createRoom", content)
- self.assertEqual(400, channel.code)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code)
self.assertEqual(
"Cannot invite so many users at once",
channel.json_body["error"],
@@ -675,11 +794,13 @@ class RoomsCreateTestCase(RoomBase):
# Test that the invites aren't ratelimited anymore.
channel = self.make_request("POST", "/createRoom", content)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
- def test_spam_checker_may_join_room(self) -> None:
+ def test_spam_checker_may_join_room_deprecated(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly bypassed
when creating a new room.
+
+ In this test, we use the deprecated API in which callbacks return a bool.
"""
async def user_may_join_room(
@@ -697,10 +818,55 @@ class RoomsCreateTestCase(RoomBase):
"/createRoom",
{},
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
self.assertEqual(join_mock.call_count, 0)
+ def test_spam_checker_may_join_room(self) -> None:
+ """Tests that the user_may_join_room spam checker callback is correctly bypassed
+ when creating a new room.
+
+ In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.
+ """
+
+ async def user_may_join_room_codes(
+ mxid: str,
+ room_id: str,
+ is_invite: bool,
+ ) -> Codes:
+ return Codes.CONSENT_NOT_GIVEN
+
+ join_mock = Mock(side_effect=user_may_join_room_codes)
+ self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock)
+
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ {},
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+
+ self.assertEqual(join_mock.call_count, 0)
+
+ # Now change the return value of the callback to deny any join. Since we're
+ # creating the room, despite the return value, we should be able to join.
+ async def user_may_join_room_tuple(
+ mxid: str,
+ room_id: str,
+ is_invite: bool,
+ ) -> Tuple[Codes, dict]:
+ return Codes.INCOMPATIBLE_ROOM_VERSION, {}
+
+ join_mock.side_effect = user_may_join_room_tuple
+
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ {},
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+ self.assertEqual(join_mock.call_count, 0)
+
class RoomTopicTestCase(RoomBase):
"""Tests /rooms/$room_id/topic REST events."""
@@ -715,54 +881,68 @@ class RoomTopicTestCase(RoomBase):
def test_invalid_puts(self) -> None:
# missing keys or invalid json
channel = self.make_request("PUT", self.path, "{}")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", self.path, '{"_name":"bo"}')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", self.path, '{"nao')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request(
"PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]'
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", self.path, "text only")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", self.path, "")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
# valid key, wrong type
content = '{"topic":["Topic name"]}'
channel = self.make_request("PUT", self.path, content)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
def test_rooms_topic(self) -> None:
# nothing should be there
channel = self.make_request("GET", self.path)
- self.assertEqual(404, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.result["body"])
# valid put
content = '{"topic":"Topic name"}'
channel = self.make_request("PUT", self.path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# valid get
channel = self.make_request("GET", self.path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
def test_rooms_topic_with_extra_keys(self) -> None:
# valid put with extra keys
content = '{"topic":"Seasons","subtopic":"Summer"}'
channel = self.make_request("PUT", self.path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# valid get
channel = self.make_request("GET", self.path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
@@ -778,22 +958,34 @@ class RoomMemberStateTestCase(RoomBase):
path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
# missing keys or invalid json
channel = self.make_request("PUT", path, "{}")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, '{"_name":"bo"}')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, '{"nao')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, "text only")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, "")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
# valid keys, wrong types
content = '{"membership":["%s","%s","%s"]}' % (
@@ -802,7 +994,9 @@ class RoomMemberStateTestCase(RoomBase):
Membership.LEAVE,
)
channel = self.make_request("PUT", path, content.encode("ascii"))
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
def test_rooms_members_self(self) -> None:
path = "/rooms/%s/state/m.room.member/%s" % (
@@ -813,10 +1007,10 @@ class RoomMemberStateTestCase(RoomBase):
# valid join message (NOOP since we made the room)
content = '{"membership":"%s"}' % Membership.JOIN
channel = self.make_request("PUT", path, content.encode("ascii"))
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", path, content=b"")
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
expected_response = {"membership": Membership.JOIN}
self.assertEqual(expected_response, channel.json_body)
@@ -831,10 +1025,10 @@ class RoomMemberStateTestCase(RoomBase):
# valid invite message
content = '{"membership":"%s"}' % Membership.INVITE
channel = self.make_request("PUT", path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", path, content=b"")
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual(json.loads(content), channel.json_body)
def test_rooms_members_other_custom_keys(self) -> None:
@@ -850,10 +1044,10 @@ class RoomMemberStateTestCase(RoomBase):
"Join us!",
)
channel = self.make_request("PUT", path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", path, content=b"")
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual(json.loads(content), channel.json_body)
@@ -911,9 +1105,11 @@ class RoomJoinTestCase(RoomBase):
self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
- def test_spam_checker_may_join_room(self) -> None:
+ def test_spam_checker_may_join_room_deprecated(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly called
and blocks room joins when needed.
+
+ This test uses the deprecated API, in which callbacks return booleans.
"""
# Register a dummy callback. Make it allow all room joins for now.
@@ -926,6 +1122,8 @@ class RoomJoinTestCase(RoomBase):
) -> bool:
return return_value
+ # `spec` argument is needed for this function mock to have `__qualname__`, which
+ # is needed for `Measure` metrics buried in SpamChecker.
callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None)
self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
@@ -966,7 +1164,92 @@ class RoomJoinTestCase(RoomBase):
# Now make the callback deny all room joins, and check that a join actually fails.
return_value = False
- self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2)
+ self.helper.join(
+ self.room3, self.user2, expect_code=HTTPStatus.FORBIDDEN, tok=self.tok2
+ )
+
+ def test_spam_checker_may_join_room(self) -> None:
+ """Tests that the user_may_join_room spam checker callback is correctly called
+ and blocks room joins when needed.
+
+ This test uses the latest API to this day, in which callbacks return `NOT_SPAM` or `Codes`.
+ """
+
+ # Register a dummy callback. Make it allow all room joins for now.
+ return_value: Union[
+ Literal["NOT_SPAM"], Tuple[Codes, dict], Codes
+ ] = synapse.module_api.NOT_SPAM
+
+ async def user_may_join_room(
+ userid: str,
+ room_id: str,
+ is_invited: bool,
+ ) -> Union[Literal["NOT_SPAM"], Tuple[Codes, dict], Codes]:
+ return return_value
+
+ # `spec` argument is needed for this function mock to have `__qualname__`, which
+ # is needed for `Measure` metrics buried in SpamChecker.
+ callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None)
+ self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
+
+ # Join a first room, without being invited to it.
+ self.helper.join(self.room1, self.user2, tok=self.tok2)
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = (
+ (
+ self.user2,
+ self.room1,
+ False,
+ ),
+ )
+ self.assertEqual(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Join a second room, this time with an invite for it.
+ self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1)
+ self.helper.join(self.room2, self.user2, tok=self.tok2)
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = (
+ (
+ self.user2,
+ self.room2,
+ True,
+ ),
+ )
+ self.assertEqual(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Now make the callback deny all room joins, and check that a join actually fails.
+ # We pick an arbitrary Codes rather than the default `Codes.FORBIDDEN`.
+ return_value = Codes.CONSENT_NOT_GIVEN
+ self.helper.invite(self.room3, self.user1, self.user2, tok=self.tok1)
+ self.helper.join(
+ self.room3,
+ self.user2,
+ expect_code=HTTPStatus.FORBIDDEN,
+ expect_errcode=return_value,
+ tok=self.tok2,
+ )
+
+ # Now make the callback deny all room joins, and check that a join actually fails.
+ # As above, with the experimental extension that lets us return dictionaries.
+ return_value = (Codes.BAD_ALIAS, {"another_field": "12345"})
+ self.helper.join(
+ self.room3,
+ self.user2,
+ expect_code=HTTPStatus.FORBIDDEN,
+ expect_errcode=return_value[0],
+ tok=self.tok2,
+ expect_additional_fields=return_value[1],
+ )
class RoomJoinRatelimitTestCase(RoomBase):
@@ -1016,7 +1299,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
# Update the display name for the user.
path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
channel = self.make_request("PUT", path, {"displayname": "John Doe"})
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# Check that all the rooms have been sent a profile update into.
for room_id in room_ids:
@@ -1081,40 +1364,153 @@ class RoomMessagesTestCase(RoomBase):
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
# missing keys or invalid json
channel = self.make_request("PUT", path, b"{}")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b'{"_name":"bo"}')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b'{"nao')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b"text only")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b"")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
def test_rooms_messages_sent(self) -> None:
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
content = b'{"body":"test","msgtype":{"type":"a"}}'
channel = self.make_request("PUT", path, content)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
# custom message types
content = b'{"body":"test","msgtype":"test.custom.text"}'
channel = self.make_request("PUT", path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# m.text message type
path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id))
content = b'{"body":"test2","msgtype":"m.text"}'
channel = self.make_request("PUT", path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
+
+ @parameterized.expand(
+ [
+ # Allow
+ param(
+ name="NOT_SPAM",
+ value="NOT_SPAM",
+ expected_code=HTTPStatus.OK,
+ expected_fields={},
+ ),
+ param(
+ name="False",
+ value=False,
+ expected_code=HTTPStatus.OK,
+ expected_fields={},
+ ),
+ # Block
+ param(
+ name="scalene string",
+ value="ANY OTHER STRING",
+ expected_code=HTTPStatus.FORBIDDEN,
+ expected_fields={"errcode": "M_FORBIDDEN"},
+ ),
+ param(
+ name="True",
+ value=True,
+ expected_code=HTTPStatus.FORBIDDEN,
+ expected_fields={"errcode": "M_FORBIDDEN"},
+ ),
+ param(
+ name="Code",
+ value=Codes.LIMIT_EXCEEDED,
+ expected_code=HTTPStatus.FORBIDDEN,
+ expected_fields={"errcode": "M_LIMIT_EXCEEDED"},
+ ),
+ param(
+ name="Tuple",
+ value=(Codes.SERVER_NOT_TRUSTED, {"additional_field": "12345"}),
+ expected_code=HTTPStatus.FORBIDDEN,
+ expected_fields={
+ "errcode": "M_SERVER_NOT_TRUSTED",
+ "additional_field": "12345",
+ },
+ ),
+ ]
+ )
+ def test_spam_checker_check_event_for_spam(
+ self,
+ name: str,
+ value: Union[str, bool, Codes, Tuple[Codes, JsonDict]],
+ expected_code: int,
+ expected_fields: dict,
+ ) -> None:
+ class SpamCheck:
+ mock_return_value: Union[
+ str, bool, Codes, Tuple[Codes, JsonDict], bool
+ ] = "NOT_SPAM"
+ mock_content: Optional[JsonDict] = None
+
+ async def check_event_for_spam(
+ self,
+ event: synapse.events.EventBase,
+ ) -> Union[str, Codes, Tuple[Codes, JsonDict], bool]:
+ self.mock_content = event.content
+ return self.mock_return_value
+
+ spam_checker = SpamCheck()
+
+ self.hs.get_spam_checker()._check_event_for_spam_callbacks.append(
+ spam_checker.check_event_for_spam
+ )
+
+ # Inject `value` as mock_return_value
+ spam_checker.mock_return_value = value
+ path = "/rooms/%s/send/m.room.message/check_event_for_spam_%s" % (
+ urlparse.quote(self.room_id),
+ urlparse.quote(name),
+ )
+ body = "test-%s" % name
+ content = '{"body":"%s","msgtype":"m.text"}' % body
+ channel = self.make_request("PUT", path, content)
+
+ # Check that the callback has witnessed the correct event.
+ self.assertIsNotNone(spam_checker.mock_content)
+ if (
+ spam_checker.mock_content is not None
+ ): # Checked just above, but mypy doesn't know about that.
+ self.assertEqual(
+ spam_checker.mock_content["body"], body, spam_checker.mock_content
+ )
+
+ # Check that we have the correct result.
+ self.assertEqual(expected_code, channel.code, msg=channel.result["body"])
+ for expected_key, expected_value in expected_fields.items():
+ self.assertEqual(
+ channel.json_body.get(expected_key, None),
+ expected_value,
+ "Field %s absent or invalid " % expected_key,
+ )
class RoomPowerLevelOverridesTestCase(RoomBase):
@@ -1239,7 +1635,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
channel = self.make_request("PUT", path, "{}")
# Then I am allowed
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
def test_normal_user_can_not_post_state_event(self) -> None:
# Given I am a normal member of a room
@@ -1253,7 +1649,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
channel = self.make_request("PUT", path, "{}")
# Then I am not allowed because state events require PL>=50
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
self.assertEqual(
"You don't have permission to post that to the room. "
"user_level (0) < send_level (50)",
@@ -1280,7 +1676,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
channel = self.make_request("PUT", path, "{}")
# Then I am allowed
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
@unittest.override_config(
{
@@ -1308,7 +1704,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
channel = self.make_request("PUT", path, "{}")
# Then I am not allowed
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
@unittest.override_config(
{
@@ -1336,7 +1732,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
channel = self.make_request("PUT", path, "{}")
# Then I am not allowed
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
self.assertEqual(
"You don't have permission to post that to the room. "
+ "user_level (0) < send_level (1)",
@@ -1367,7 +1763,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
# Then I am not allowed because the public_chat config does not
# affect this room, because this room is a private_chat
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
self.assertEqual(
"You don't have permission to post that to the room. "
+ "user_level (0) < send_level (50)",
@@ -1386,7 +1782,7 @@ class RoomInitialSyncTestCase(RoomBase):
def test_initial_sync(self) -> None:
channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertEqual(self.room_id, channel.json_body["room_id"])
self.assertEqual("join", channel.json_body["membership"])
@@ -1429,7 +1825,7 @@ class RoomMessageListTestCase(RoomBase):
channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertTrue("start" in channel.json_body)
self.assertEqual(token, channel.json_body["start"])
self.assertTrue("chunk" in channel.json_body)
@@ -1440,7 +1836,7 @@ class RoomMessageListTestCase(RoomBase):
channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertTrue("start" in channel.json_body)
self.assertEqual(token, channel.json_body["start"])
self.assertTrue("chunk" in channel.json_body)
@@ -1479,7 +1875,7 @@ class RoomMessageListTestCase(RoomBase):
json.dumps({"types": [EventTypes.Message]}),
),
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
@@ -1507,7 +1903,7 @@ class RoomMessageListTestCase(RoomBase):
json.dumps({"types": [EventTypes.Message]}),
),
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 1, [event["content"] for event in chunk])
@@ -1524,7 +1920,7 @@ class RoomMessageListTestCase(RoomBase):
json.dumps({"types": [EventTypes.Message]}),
),
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
@@ -1652,14 +2048,97 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
def test_restricted_no_auth(self) -> None:
channel = self.make_request("GET", self.url)
- self.assertEqual(channel.code, 401, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
def test_restricted_auth(self) -> None:
self.register_user("user", "pass")
tok = self.login("user", "pass")
channel = self.make_request("GET", self.url, access_token=tok)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+
+class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+
+ config = self.default_config()
+ config["allow_public_rooms_without_auth"] = True
+ self.hs = self.setup_test_homeserver(config=config)
+ self.url = b"/_matrix/client/r0/publicRooms"
+
+ return self.hs
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ user = self.register_user("alice", "pass")
+ self.token = self.login(user, "pass")
+
+ # Create a room
+ self.helper.create_room_as(
+ user,
+ is_public=True,
+ extra_content={"visibility": "public"},
+ tok=self.token,
+ )
+ # Create a space
+ self.helper.create_room_as(
+ user,
+ is_public=True,
+ extra_content={
+ "visibility": "public",
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE},
+ },
+ tok=self.token,
+ )
+
+ def make_public_rooms_request(
+ self, room_types: Union[List[Union[str, None]], None]
+ ) -> Tuple[List[Dict[str, Any]], int]:
+ channel = self.make_request(
+ "POST",
+ self.url,
+ {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}},
+ self.token,
+ )
+ chunk = channel.json_body["chunk"]
+ count = channel.json_body["total_room_count_estimate"]
+
+ self.assertEqual(len(chunk), count)
+
+ return chunk, count
+
+ def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None:
+ chunk, count = self.make_public_rooms_request(None)
+
+ self.assertEqual(count, 2)
+
+ def test_returns_only_rooms_based_on_filter(self) -> None:
+ chunk, count = self.make_public_rooms_request([None])
+
+ self.assertEqual(count, 1)
+ self.assertEqual(chunk[0].get("room_type", None), None)
+
+ def test_returns_only_space_based_on_filter(self) -> None:
+ chunk, count = self.make_public_rooms_request(["m.space"])
+
+ self.assertEqual(count, 1)
+ self.assertEqual(chunk[0].get("room_type", None), "m.space")
+
+ def test_returns_both_rooms_and_space_based_on_filter(self) -> None:
+ chunk, count = self.make_public_rooms_request(["m.space", None])
+
+ self.assertEqual(count, 2)
+
+ def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None:
+ chunk, count = self.make_public_rooms_request([])
+
+ self.assertEqual(count, 2)
class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
@@ -1686,7 +2165,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
"Simple test for searching rooms over federation"
self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined]
- search_filter = {"generic_search_term": "foobar"}
+ search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
channel = self.make_request(
"POST",
@@ -1694,7 +2173,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
content={"filter": search_filter},
access_token=self.token,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.federation_client.get_public_rooms.assert_called_once_with( # type: ignore[attr-defined]
"testserv",
@@ -1711,11 +2190,11 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
# The `get_public_rooms` should be called again if the first call fails
# with a 404, when using search filters.
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
- HttpResponseException(404, "Not Found", b""),
+ HttpResponseException(HTTPStatus.NOT_FOUND, "Not Found", b""),
make_awaitable({}),
)
- search_filter = {"generic_search_term": "foobar"}
+ search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
channel = self.make_request(
"POST",
@@ -1723,7 +2202,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
content={"filter": search_filter},
access_token=self.token,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.federation_client.get_public_rooms.assert_has_calls( # type: ignore[attr-defined]
[
@@ -1769,21 +2248,19 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
# Set a profile for the test user
self.displayname = "test user"
- data = {"displayname": self.displayname}
- request_data = json.dumps(data)
+ request_data = {"displayname": self.displayname}
channel = self.make_request(
"PUT",
"/_matrix/client/r0/profile/%s/displayname" % (self.user_id,),
request_data,
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
def test_per_room_profile_forbidden(self) -> None:
- data = {"membership": "join", "displayname": "other test user"}
- request_data = json.dumps(data)
+ request_data = {"membership": "join", "displayname": "other test user"}
channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
@@ -1791,7 +2268,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
request_data,
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -1799,7 +2276,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
res_displayname = channel.json_body["content"]["displayname"]
self.assertEqual(res_displayname, self.displayname, channel.result)
@@ -1833,7 +2310,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason},
access_token=self.second_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -1847,7 +2324,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason},
access_token=self.second_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -1861,7 +2338,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.second_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -1875,7 +2352,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -1887,7 +2364,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -1899,7 +2376,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -1918,7 +2395,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason},
access_token=self.second_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -1930,7 +2407,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
),
access_token=self.creator_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
event_content = channel.json_body
@@ -1978,7 +2455,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id, json.dumps(self.FILTER_LABELS)),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
events_before = channel.json_body["events_before"]
@@ -2008,7 +2485,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
events_before = channel.json_body["events_before"]
@@ -2043,7 +2520,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
events_before = channel.json_body["events_before"]
@@ -2123,16 +2600,14 @@ class LabelsTestCase(unittest.HomeserverTestCase):
def test_search_filter_labels(self) -> None:
"""Test that we can filter by a label on a /search request."""
- request_data = json.dumps(
- {
- "search_categories": {
- "room_events": {
- "search_term": "label",
- "filter": self.FILTER_LABELS,
- }
+ request_data = {
+ "search_categories": {
+ "room_events": {
+ "search_term": "label",
+ "filter": self.FILTER_LABELS,
}
}
- )
+ }
self._send_labelled_messages_in_room()
@@ -2160,16 +2635,14 @@ class LabelsTestCase(unittest.HomeserverTestCase):
def test_search_filter_not_labels(self) -> None:
"""Test that we can filter by the absence of a label on a /search request."""
- request_data = json.dumps(
- {
- "search_categories": {
- "room_events": {
- "search_term": "label",
- "filter": self.FILTER_NOT_LABELS,
- }
+ request_data = {
+ "search_categories": {
+ "room_events": {
+ "search_term": "label",
+ "filter": self.FILTER_NOT_LABELS,
}
}
- )
+ }
self._send_labelled_messages_in_room()
@@ -2209,16 +2682,14 @@ class LabelsTestCase(unittest.HomeserverTestCase):
"""Test that we can filter by both a label and the absence of another label on a
/search request.
"""
- request_data = json.dumps(
- {
- "search_categories": {
- "room_events": {
- "search_term": "label",
- "filter": self.FILTER_LABELS_NOT_LABELS,
- }
+ request_data = {
+ "search_categories": {
+ "room_events": {
+ "search_term": "label",
+ "filter": self.FILTER_LABELS_NOT_LABELS,
}
}
- )
+ }
self._send_labelled_messages_in_room()
@@ -2391,7 +2862,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
return channel.json_body["chunk"]
@@ -2496,7 +2967,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
events_before = channel.json_body["events_before"]
@@ -2562,7 +3033,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id),
access_token=invited_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
events_before = channel.json_body["events_before"]
@@ -2663,8 +3134,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None:
url = "/_matrix/client/r0/directory/room/" + alias
- data = {"room_id": self.room_id}
- request_data = json.dumps(data)
+ request_data = {"room_id": self.room_id}
channel = self.make_request(
"PUT", url, request_data, access_token=self.room_owner_tok
@@ -2693,8 +3163,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None:
url = "/_matrix/client/r0/directory/room/" + alias
- data = {"room_id": self.room_id}
- request_data = json.dumps(data)
+ request_data = {"room_id": self.room_id}
channel = self.make_request(
"PUT", url, request_data, access_token=self.room_owner_tok
@@ -2720,7 +3189,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
"rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
- json.dumps(content),
+ content,
access_token=self.room_owner_tok,
)
self.assertEqual(channel.code, expected_code, channel.result)
@@ -2845,11 +3314,16 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
- def test_threepid_invite_spamcheck(self) -> None:
+ def test_threepid_invite_spamcheck_deprecated(self) -> None:
+ """
+ Test allowing/blocking threepid invites with a spam-check module.
+
+ In this test, we use the deprecated API in which callbacks return a bool.
+ """
# Mock a few functions to prevent the test from failing due to failing to talk to
- # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we
+ # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
- make_invite_mock = Mock(return_value=make_awaitable(0))
+ make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
self.hs.get_identity_handler().lookup_3pid = Mock(
return_value=make_awaitable(None),
@@ -2901,3 +3375,107 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Also check that it stopped before calling _make_and_store_3pid_invite.
make_invite_mock.assert_called_once()
+
+ def test_threepid_invite_spamcheck(self) -> None:
+ """
+ Test allowing/blocking threepid invites with a spam-check module.
+
+ In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`."""
+ # Mock a few functions to prevent the test from failing due to failing to talk to
+ # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
+ # can check its call_count later on during the test.
+ make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
+ self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
+ self.hs.get_identity_handler().lookup_3pid = Mock(
+ return_value=make_awaitable(None),
+ )
+
+ # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
+ # allow everything for now.
+ # `spec` argument is needed for this function mock to have `__qualname__`, which
+ # is needed for `Measure` metrics buried in SpamChecker.
+ mock = Mock(
+ return_value=make_awaitable(synapse.module_api.NOT_SPAM),
+ spec=lambda *x: None,
+ )
+ self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock)
+
+ # Send a 3PID invite into the room and check that it succeeded.
+ email_to_invite = "teresa@example.com"
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": email_to_invite,
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ # Check that the callback was called with the right params.
+ mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id)
+
+ # Check that the call to send the invite was made.
+ make_invite_mock.assert_called_once()
+
+ # Now change the return value of the callback to deny any invite and test that
+ # we can't send the invite. We pick an arbitrary error code to be able to check
+ # that the same code has been returned
+ mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN)
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": email_to_invite,
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.json_body["errcode"], Codes.CONSENT_NOT_GIVEN)
+
+ # Also check that it stopped before calling _make_and_store_3pid_invite.
+ make_invite_mock.assert_called_once()
+
+ # Run variant with `Tuple[Codes, dict]`.
+ mock.return_value = make_awaitable((Codes.EXPIRED_ACCOUNT, {"field": "value"}))
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": email_to_invite,
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT)
+ self.assertEqual(channel.json_body["field"], "value")
+
+ # Also check that it stopped before calling _make_and_store_3pid_invite.
+ make_invite_mock.assert_called_once()
+
+ def test_400_missing_param_without_id_access_token(self) -> None:
+ """
+ Test that a 3pid invite request returns 400 M_MISSING_PARAM
+ if we do not include id_access_token.
+ """
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "medium": "email",
+ "address": "teresa@example.com",
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 400)
+ self.assertEqual(channel.json_body["errcode"], "M_MISSING_PARAM")
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index d9bd8c4a28..c807a37bc2 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -26,7 +26,7 @@ from synapse.rest.client import (
room_upgrade_rest_servlet,
)
from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
from synapse.util import Clock
from tests import unittest
@@ -97,7 +97,12 @@ class RoomTestCase(_ShadowBannedBase):
channel = self.make_request(
"POST",
"/rooms/%s/invite" % (room_id,),
- {"id_server": "test", "medium": "email", "address": "test@test.test"},
+ {
+ "id_server": "test",
+ "medium": "email",
+ "address": "test@test.test",
+ "id_access_token": "anytoken",
+ },
access_token=self.banned_access_token,
)
self.assertEqual(200, channel.code, channel.result)
@@ -275,7 +280,7 @@ class ProfileTestCase(_ShadowBannedBase):
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
- self.banned_user_id,
+ create_requester(self.banned_user_id),
room_id,
"m.room.member",
self.banned_user_id,
@@ -310,7 +315,7 @@ class ProfileTestCase(_ShadowBannedBase):
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
- self.banned_user_id,
+ create_requester(self.banned_user_id),
room_id,
"m.room.member",
self.banned_user_id,
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index e3efd1f1b0..0af643ecd9 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -38,7 +38,6 @@ from tests.federation.transport.test_knocking import (
KnockingStrippedStateEventHelperMixin,
)
from tests.server import TimedOutException
-from tests.unittest import override_config
class FilterTestCase(unittest.HomeserverTestCase):
@@ -390,6 +389,11 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
sync.register_servlets,
]
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ config = self.default_config()
+
+ return self.setup_test_homeserver(config=config)
+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.url = "/sync?since=%s"
self.next_batch = "s0"
@@ -408,7 +412,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Join the second user
self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
- @override_config({"experimental_features": {"msc2285_enabled": True}})
def test_private_read_receipts(self) -> None:
# Send a message as the first user
res = self.helper.send(self.room_id, body="hello", tok=self.tok)
@@ -416,7 +419,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Send a private read receipt to tell the server the first user's message was read
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
{},
access_token=self.tok2,
)
@@ -425,7 +428,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Test that the first user can't see the other user's private read receipt
self.assertIsNone(self._get_read_receipt())
- @override_config({"experimental_features": {"msc2285_enabled": True}})
def test_public_receipt_can_override_private(self) -> None:
"""
Sending a public read receipt to the same event which has a private read
@@ -456,7 +458,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Test that we did override the private read receipt
self.assertNotEqual(self._get_read_receipt(), None)
- @override_config({"experimental_features": {"msc2285_enabled": True}})
def test_private_receipt_cannot_override_public(self) -> None:
"""
Sending a private read receipt to the same event which has a public read
@@ -543,7 +544,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
config = super().default_config()
config["experimental_features"] = {
"msc2654_enabled": True,
- "msc2285_enabled": True,
}
return config
@@ -606,11 +606,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
self._check_unread_count(1)
# Send a read receipt to tell the server we've read the latest event.
- body = json.dumps({ReceiptTypes.READ: res["event_id"]}).encode("utf8")
channel = self.make_request(
"POST",
f"/rooms/{self.room_id}/read_markers",
- body,
+ {ReceiptTypes.READ: res["event_id"]},
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -625,7 +624,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Send a read receipt to tell the server we've read the latest event.
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
{},
access_token=self.tok,
)
@@ -701,7 +700,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
self._check_unread_count(5)
res2 = self.helper.send(self.room_id, "hello", tok=self.tok2)
- # Make sure both m.read and org.matrix.msc2285.read.private advance
+ # Make sure both m.read and m.read.private advance
channel = self.make_request(
"POST",
f"/rooms/{self.room_id}/receipt/m.read/{res1['event_id']}",
@@ -713,16 +712,21 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res2['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}",
{},
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
self._check_unread_count(0)
- # We test for both receipt types that influence notification counts
- @parameterized.expand([ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE])
- def test_read_receipts_only_go_down(self, receipt_type: ReceiptTypes) -> None:
+ # We test for all three receipt types that influence notification counts
+ @parameterized.expand(
+ [
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ]
+ )
+ def test_read_receipts_only_go_down(self, receipt_type: str) -> None:
# Join the new user
self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
@@ -733,18 +737,18 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Read last event
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}",
{},
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
self._check_unread_count(0)
- # Make sure neither m.read nor org.matrix.msc2285.read.private make the
+ # Make sure neither m.read nor m.read.private make the
# read receipt go up to an older event
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res1['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res1['event_id']}",
{},
access_token=self.tok,
)
@@ -949,3 +953,24 @@ class ExcludeRoomTestCase(unittest.HomeserverTestCase):
self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["invite"])
self.assertIn(self.included_room_id, channel.json_body["rooms"]["invite"])
+
+ def test_incremental_sync(self) -> None:
+ """Tests that activity in the room is properly filtered out of incremental
+ syncs.
+ """
+ channel = self.make_request("GET", "/sync", access_token=self.tok)
+ self.assertEqual(channel.code, 200, channel.result)
+ next_batch = channel.json_body["next_batch"]
+
+ self.helper.send(self.excluded_room_id, tok=self.tok)
+ self.helper.send(self.included_room_id, tok=self.tok)
+
+ channel = self.make_request(
+ "GET",
+ f"/sync?since={next_batch}",
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"])
+ self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"])
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 5eb0f243f7..3325d43a2f 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -20,8 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, LoginType, Membership
from synapse.api.errors import SynapseError
from synapse.api.room_versions import RoomVersion
+from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.rest import admin
from synapse.rest.client import account, login, profile, room
@@ -113,14 +113,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Have this homeserver skip event auth checks. This is necessary due to
# event auth checks ensuring that events were signed by the sender's homeserver.
- async def _check_event_auth(
- origin: str,
- event: EventBase,
- context: EventContext,
- *args: Any,
- **kwargs: Any,
- ) -> EventContext:
- return context
+ async def _check_event_auth(origin: Any, event: Any, context: Any) -> None:
+ pass
hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment]
@@ -161,7 +155,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
callback.assert_called_once()
@@ -179,7 +173,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, channel.result)
def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None:
"""
@@ -192,12 +186,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""
class NastyHackException(SynapseError):
- def error_dict(self) -> JsonDict:
+ def error_dict(self, config: Optional[HomeServerConfig]) -> JsonDict:
"""
This overrides SynapseError's `error_dict` to nastily inject
JSON into the error response.
"""
- result = super().error_dict()
+ result = super().error_dict(config)
result["nasty"] = "very"
return result
@@ -217,7 +211,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
access_token=self.tok,
)
# Check the error code
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, channel.result)
# Check the JSON body has had the `nasty` key injected
self.assertEqual(
channel.json_body,
@@ -266,7 +260,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{"x": "x"},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
event_id = channel.json_body["event_id"]
# ... and check that it got modified
@@ -275,7 +269,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
ev = channel.json_body
self.assertEqual(ev["content"]["x"], "y")
@@ -304,7 +298,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
orig_event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -321,7 +315,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
edited_event_id = channel.json_body["event_id"]
# ... and check that they both got modified
@@ -330,7 +324,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, orig_event_id),
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
ev = channel.json_body
self.assertEqual(ev["content"]["body"], "ORIGINAL BODY")
@@ -339,7 +333,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, edited_event_id),
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
ev = channel.json_body
self.assertEqual(ev["content"]["body"], "EDITED BODY")
@@ -385,7 +379,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
event_id = channel.json_body["event_id"]
@@ -394,7 +388,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
self.assertIn("foo", channel.json_body["content"].keys())
self.assertEqual(channel.json_body["content"]["foo"], "bar")
diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index 98c1039d33..5e7bf97482 100644
--- a/tests/rest/client/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
@@ -48,10 +48,14 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
self.helper.join(self.room_id, self.other, tok=self.other_token)
def _upgrade_room(
- self, token: Optional[str] = None, room_id: Optional[str] = None
+ self,
+ token: Optional[str] = None,
+ room_id: Optional[str] = None,
+ expire_cache: bool = True,
) -> FakeChannel:
- # We never want a cached response.
- self.reactor.advance(5 * 60 + 1)
+ if expire_cache:
+ # We don't want a cached response.
+ self.reactor.advance(5 * 60 + 1)
if room_id is None:
room_id = self.room_id
@@ -72,9 +76,24 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, channel.result)
self.assertIn("replacement_room", channel.json_body)
- def test_not_in_room(self) -> None:
+ new_room_id = channel.json_body["replacement_room"]
+
+ # Check that the tombstone event points to the new room.
+ tombstone_event = self.get_success(
+ self.hs.get_storage_controllers().state.get_current_state_event(
+ self.room_id, EventTypes.Tombstone, ""
+ )
+ )
+ self.assertIsNotNone(tombstone_event)
+ self.assertEqual(new_room_id, tombstone_event.content["replacement_room"])
+
+ # Check that the new room exists.
+ room = self.get_success(self.store.get_room(new_room_id))
+ self.assertIsNotNone(room)
+
+ def test_never_in_room(self) -> None:
"""
- Upgrading a room should work fine.
+ A user who has never been in the room cannot upgrade the room.
"""
# The user isn't in the room.
roomless = self.register_user("roomless", "pass")
@@ -83,6 +102,16 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
channel = self._upgrade_room(roomless_token)
self.assertEqual(403, channel.code, channel.result)
+ def test_left_room(self) -> None:
+ """
+ A user who is no longer in the room cannot upgrade the room.
+ """
+ # Remove the user from the room.
+ self.helper.leave(self.room_id, self.creator, tok=self.creator_token)
+
+ channel = self._upgrade_room(self.creator_token)
+ self.assertEqual(403, channel.code, channel.result)
+
def test_power_levels(self) -> None:
"""
Another user can upgrade the room if their power level is increased.
@@ -297,3 +326,47 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
self.assertEqual(
create_event.content.get(EventContentFields.ROOM_TYPE), test_room_type
)
+
+ def test_second_upgrade_from_same_user(self) -> None:
+ """A second room upgrade from the same user is deduplicated."""
+ channel1 = self._upgrade_room()
+ self.assertEqual(200, channel1.code, channel1.result)
+
+ channel2 = self._upgrade_room(expire_cache=False)
+ self.assertEqual(200, channel2.code, channel2.result)
+
+ self.assertEqual(
+ channel1.json_body["replacement_room"],
+ channel2.json_body["replacement_room"],
+ )
+
+ def test_second_upgrade_after_delay(self) -> None:
+ """A second room upgrade is not deduplicated after some time has passed."""
+ channel1 = self._upgrade_room()
+ self.assertEqual(200, channel1.code, channel1.result)
+
+ channel2 = self._upgrade_room(expire_cache=True)
+ self.assertEqual(200, channel2.code, channel2.result)
+
+ self.assertNotEqual(
+ channel1.json_body["replacement_room"],
+ channel2.json_body["replacement_room"],
+ )
+
+ def test_second_upgrade_from_different_user(self) -> None:
+ """A second room upgrade from a different user is blocked."""
+ channel = self._upgrade_room()
+ self.assertEqual(200, channel.code, channel.result)
+
+ channel = self._upgrade_room(self.other_token, expire_cache=False)
+ self.assertEqual(400, channel.code, channel.result)
+
+ def test_first_upgrade_does_not_block_second(self) -> None:
+ """A second room upgrade is not blocked when a previous upgrade attempt was not
+ allowed.
+ """
+ channel = self._upgrade_room(self.other_token)
+ self.assertEqual(403, channel.code, channel.result)
+
+ channel = self._upgrade_room(expire_cache=False)
+ self.assertEqual(200, channel.code, channel.result)
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index a0788b1bb0..dd26145bf8 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -41,6 +41,7 @@ from twisted.web.resource import Resource
from twisted.web.server import Site
from synapse.api.constants import Membership
+from synapse.api.errors import Codes
from synapse.server import HomeServer
from synapse.types import JsonDict
@@ -135,11 +136,11 @@ class RestHelper:
self.site,
"POST",
path,
- json.dumps(content).encode("utf8"),
+ content,
custom_headers=custom_headers,
)
- assert channel.result["code"] == b"%d" % expect_code, channel.result
+ assert channel.code == expect_code, channel.result
self.auth_user_id = temp_id
if expect_code == HTTPStatus.OK:
@@ -171,6 +172,8 @@ class RestHelper:
expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None,
appservice_user_id: Optional[str] = None,
+ expect_errcode: Optional[Codes] = None,
+ expect_additional_fields: Optional[dict] = None,
) -> None:
self.change_membership(
room=room,
@@ -180,6 +183,8 @@ class RestHelper:
appservice_user_id=appservice_user_id,
membership=Membership.JOIN,
expect_code=expect_code,
+ expect_errcode=expect_errcode,
+ expect_additional_fields=expect_additional_fields,
)
def knock(
@@ -205,14 +210,12 @@ class RestHelper:
self.site,
"POST",
path,
- json.dumps(data).encode("utf8"),
+ data,
)
- assert (
- int(channel.result["code"]) == expect_code
- ), "Expected: %d, got: %d, resp: %r" % (
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
- int(channel.result["code"]),
+ channel.code,
channel.result["body"],
)
@@ -263,6 +266,7 @@ class RestHelper:
appservice_user_id: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
expect_errcode: Optional[str] = None,
+ expect_additional_fields: Optional[dict] = None,
) -> None:
"""
Send a membership state event into a room.
@@ -303,14 +307,12 @@ class RestHelper:
self.site,
"PUT",
path,
- json.dumps(data).encode("utf8"),
+ data,
)
- assert (
- int(channel.result["code"]) == expect_code
- ), "Expected: %d, got: %d, resp: %r" % (
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
- int(channel.result["code"]),
+ channel.code,
channel.result["body"],
)
@@ -323,6 +325,21 @@ class RestHelper:
channel.result["body"],
)
+ if expect_additional_fields is not None:
+ for expect_key, expect_value in expect_additional_fields.items():
+ assert expect_key in channel.json_body, "Expected field %s, got %s" % (
+ expect_key,
+ channel.json_body,
+ )
+ assert (
+ channel.json_body[expect_key] == expect_value
+ ), "Expected: %s at %s, got: %s, resp: %s" % (
+ expect_value,
+ expect_key,
+ channel.json_body[expect_key],
+ channel.json_body,
+ )
+
self.auth_user_id = temp_id
def send(
@@ -371,15 +388,13 @@ class RestHelper:
self.site,
"PUT",
path,
- json.dumps(content or {}).encode("utf8"),
+ content or {},
custom_headers=custom_headers,
)
- assert (
- int(channel.result["code"]) == expect_code
- ), "Expected: %d, got: %d, resp: %r" % (
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
- int(channel.result["code"]),
+ channel.code,
channel.result["body"],
)
@@ -428,11 +443,9 @@ class RestHelper:
channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
- assert (
- int(channel.result["code"]) == expect_code
- ), "Expected: %d, got: %d, resp: %r" % (
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
- int(channel.result["code"]),
+ channel.code,
channel.result["body"],
)
@@ -524,7 +537,7 @@ class RestHelper:
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
- int(channel.result["code"]),
+ channel.code,
channel.result["body"],
)
diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py
index ea9e5889bf..1062081a06 100644
--- a/tests/rest/media/v1/test_html_preview.py
+++ b/tests/rest/media/v1/test_html_preview.py
@@ -370,6 +370,64 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."})
+ def test_twitter_tag(self) -> None:
+ """Twitter card tags should be used if nothing else is available."""
+ html = b"""
+ <html>
+ <meta name="twitter:card" content="summary">
+ <meta name="twitter:description" content="Description">
+ <meta name="twitter:site" content="@matrixdotorg">
+ </html>
+ """
+ tree = decode_body(html, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
+ self.assertEqual(
+ og,
+ {
+ "og:title": None,
+ "og:description": "Description",
+ "og:site_name": "@matrixdotorg",
+ },
+ )
+
+ # But they shouldn't override Open Graph values.
+ html = b"""
+ <html>
+ <meta name="twitter:card" content="summary">
+ <meta name="twitter:description" content="Description">
+ <meta property="og:description" content="Real Description">
+ <meta name="twitter:site" content="@matrixdotorg">
+ <meta property="og:site_name" content="matrix.org">
+ </html>
+ """
+ tree = decode_body(html, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
+ self.assertEqual(
+ og,
+ {
+ "og:title": None,
+ "og:description": "Real Description",
+ "og:site_name": "matrix.org",
+ },
+ )
+
+ def test_nested_nodes(self) -> None:
+ """A body with some nested nodes. Tests that we iterate over children
+ in the right order (and don't reverse the order of the text)."""
+ html = b"""
+ <a href="somewhere">Welcome <b>the bold <u>and underlined text <svg>
+ with a cheeky SVG</svg></u> and <strong>some</strong> tail text</b></a>
+ """
+ tree = decode_body(html, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
+ self.assertEqual(
+ og,
+ {
+ "og:title": None,
+ "og:description": "Welcome\n\nthe bold\n\nand underlined text\n\nand\n\nsome\n\ntail text",
+ },
+ )
+
class MediaEncodingTestCase(unittest.TestCase):
def test_meta_charset(self) -> None:
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 7204b2dfe0..d18fc13c21 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -23,11 +23,13 @@ from urllib import parse
import attr
from parameterized import parameterized, parameterized_class
from PIL import Image as Image
+from typing_extensions import Literal
from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor
+from synapse.api.errors import Codes
from synapse.events import EventBase
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.logging.context import make_deferred_yieldable
@@ -124,7 +126,9 @@ class _TestImage:
expected_scaled: The expected bytes from scaled thumbnailing, or None if
test should just check for a valid image returned.
expected_found: True if the file should exist on the server, or False if
- a 404 is expected.
+ a 404/400 is expected.
+ unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or
+ False if the thumbnailing should succeed or a normal 404 is expected.
"""
data: bytes
@@ -133,6 +137,7 @@ class _TestImage:
expected_cropped: Optional[bytes] = None
expected_scaled: Optional[bytes] = None
expected_found: bool = True
+ unable_to_thumbnail: bool = False
@parameterized_class(
@@ -190,6 +195,7 @@ class _TestImage:
b"image/gif",
b".gif",
expected_found=False,
+ unable_to_thumbnail=True,
),
),
],
@@ -364,18 +370,29 @@ class MediaRepoTests(unittest.HomeserverTestCase):
def test_thumbnail_crop(self) -> None:
"""Test that a cropped remote thumbnail is available."""
self._test_thumbnail(
- "crop", self.test_image.expected_cropped, self.test_image.expected_found
+ "crop",
+ self.test_image.expected_cropped,
+ expected_found=self.test_image.expected_found,
+ unable_to_thumbnail=self.test_image.unable_to_thumbnail,
)
def test_thumbnail_scale(self) -> None:
"""Test that a scaled remote thumbnail is available."""
self._test_thumbnail(
- "scale", self.test_image.expected_scaled, self.test_image.expected_found
+ "scale",
+ self.test_image.expected_scaled,
+ expected_found=self.test_image.expected_found,
+ unable_to_thumbnail=self.test_image.unable_to_thumbnail,
)
def test_invalid_type(self) -> None:
"""An invalid thumbnail type is never available."""
- self._test_thumbnail("invalid", None, False)
+ self._test_thumbnail(
+ "invalid",
+ None,
+ expected_found=False,
+ unable_to_thumbnail=self.test_image.unable_to_thumbnail,
+ )
@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
@@ -384,7 +401,12 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"""
Override the config to generate only scaled thumbnails, but request a cropped one.
"""
- self._test_thumbnail("crop", None, False)
+ self._test_thumbnail(
+ "crop",
+ None,
+ expected_found=False,
+ unable_to_thumbnail=self.test_image.unable_to_thumbnail,
+ )
@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
@@ -393,14 +415,22 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"""
Override the config to generate only cropped thumbnails, but request a scaled one.
"""
- self._test_thumbnail("scale", None, False)
+ self._test_thumbnail(
+ "scale",
+ None,
+ expected_found=False,
+ unable_to_thumbnail=self.test_image.unable_to_thumbnail,
+ )
def test_thumbnail_repeated_thumbnail(self) -> None:
"""Test that fetching the same thumbnail works, and deleting the on disk
thumbnail regenerates it.
"""
self._test_thumbnail(
- "scale", self.test_image.expected_scaled, self.test_image.expected_found
+ "scale",
+ self.test_image.expected_scaled,
+ expected_found=self.test_image.expected_found,
+ unable_to_thumbnail=self.test_image.unable_to_thumbnail,
)
if not self.test_image.expected_found:
@@ -457,8 +487,24 @@ class MediaRepoTests(unittest.HomeserverTestCase):
)
def _test_thumbnail(
- self, method: str, expected_body: Optional[bytes], expected_found: bool
+ self,
+ method: str,
+ expected_body: Optional[bytes],
+ expected_found: bool,
+ unable_to_thumbnail: bool = False,
) -> None:
+ """Test the given thumbnailing method works as expected.
+
+ Args:
+ method: The thumbnailing method to use (crop, scale).
+ expected_body: The expected bytes from thumbnailing, or None if
+ test should just check for a valid image.
+ expected_found: True if the file should exist on the server, or False if
+ a 404/400 is expected.
+ unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or
+ False if the thumbnailing should succeed or a normal 404 is expected.
+ """
+
params = "?width=32&height=32&method=" + method
channel = make_request(
self.reactor,
@@ -481,6 +527,12 @@ class MediaRepoTests(unittest.HomeserverTestCase):
if expected_found:
self.assertEqual(channel.code, 200)
+
+ self.assertEqual(
+ channel.headers.getRawHeaders(b"Cross-Origin-Resource-Policy"),
+ [b"cross-origin"],
+ )
+
if expected_body is not None:
self.assertEqual(
channel.result["body"], expected_body, channel.result["body"]
@@ -488,6 +540,16 @@ class MediaRepoTests(unittest.HomeserverTestCase):
else:
# ensure that the result is at least some valid image
Image.open(BytesIO(channel.result["body"]))
+ elif unable_to_thumbnail:
+ # A 400 with a JSON body.
+ self.assertEqual(channel.code, 400)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "errcode": "M_UNKNOWN",
+ "error": "Cannot find any thumbnails for the requested media ([b'example.com', b'12345']). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
+ },
+ )
else:
# A 404 with a JSON body.
self.assertEqual(channel.code, 404)
@@ -549,10 +611,26 @@ class MediaRepoTests(unittest.HomeserverTestCase):
[b"noindex, nofollow, noarchive, noimageindex"],
)
+ def test_cross_origin_resource_policy_header(self) -> None:
+ """
+ Test that the Cross-Origin-Resource-Policy header is set to "cross-origin"
+ allowing web clients to embed media from the downloads API.
+ """
+ channel = self._req(b"inline; filename=out" + self.test_image.extension)
-class TestSpamChecker:
+ headers = channel.headers
+
+ self.assertEqual(
+ headers.getRawHeaders(b"Cross-Origin-Resource-Policy"),
+ [b"cross-origin"],
+ )
+
+
+class TestSpamCheckerLegacy:
"""A spam checker module that rejects all media that includes the bytes
`evil`.
+
+ Uses the legacy Spam-Checker API.
"""
def __init__(self, config: Dict[str, Any], api: ModuleApi) -> None:
@@ -593,7 +671,7 @@ class TestSpamChecker:
return b"evil" in buf.getvalue()
-class SpamCheckerTestCase(unittest.HomeserverTestCase):
+class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
admin.register_servlets,
@@ -617,7 +695,8 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
{
"spam_checker": [
{
- "module": TestSpamChecker.__module__ + ".TestSpamChecker",
+ "module": TestSpamCheckerLegacy.__module__
+ + ".TestSpamCheckerLegacy",
"config": {},
}
]
@@ -642,3 +721,62 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
self.helper.upload_media(
self.upload_resource, data, tok=self.tok, expect_code=400
)
+
+
+EVIL_DATA = b"Some evil data"
+EVIL_DATA_EXPERIMENT = b"Some evil data to trigger the experimental tuple API"
+
+
+class SpamCheckerTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ admin.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.user = self.register_user("user", "pass")
+ self.tok = self.login("user", "pass")
+
+ # Allow for uploading and downloading to/from the media repo
+ self.media_repo = hs.get_media_repository_resource()
+ self.download_resource = self.media_repo.children[b"download"]
+ self.upload_resource = self.media_repo.children[b"upload"]
+
+ hs.get_module_api().register_spam_checker_callbacks(
+ check_media_file_for_spam=self.check_media_file_for_spam
+ )
+
+ async def check_media_file_for_spam(
+ self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
+ ) -> Union[Codes, Literal["NOT_SPAM"]]:
+ buf = BytesIO()
+ await file_wrapper.write_chunks_to(buf.write)
+
+ if buf.getvalue() == EVIL_DATA:
+ return Codes.FORBIDDEN
+ elif buf.getvalue() == EVIL_DATA_EXPERIMENT:
+ return (Codes.FORBIDDEN, {})
+ else:
+ return "NOT_SPAM"
+
+ def test_upload_innocent(self) -> None:
+ """Attempt to upload some innocent data that should be allowed."""
+ self.helper.upload_media(
+ self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
+ )
+
+ def test_upload_ban(self) -> None:
+ """Attempt to upload some data that includes bytes "evil", which should
+ get rejected by the spam checker.
+ """
+
+ self.helper.upload_media(
+ self.upload_resource, EVIL_DATA, tok=self.tok, expect_code=400
+ )
+
+ self.helper.upload_media(
+ self.upload_resource,
+ EVIL_DATA_EXPERIMENT,
+ tok=self.tok,
+ expect_code=400,
+ )
diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
index da325955f8..c0a2501742 100644
--- a/tests/rest/test_health.py
+++ b/tests/rest/test_health.py
@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
-
from synapse.rest.health import HealthResource
from tests import unittest
@@ -26,5 +24,5 @@ class HealthCheckTests(unittest.HomeserverTestCase):
def test_health(self) -> None:
channel = self.make_request("GET", "/health", shorthand=False)
- self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.result["body"], b"OK")
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index 11f78f52b8..2091b08d89 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
-
from twisted.web.resource import Resource
from synapse.rest.well_known import well_known_resource
@@ -38,7 +36,7 @@ class WellKnownTests(unittest.HomeserverTestCase):
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(channel.code, 200)
self.assertEqual(
channel.json_body,
{
@@ -57,7 +55,29 @@ class WellKnownTests(unittest.HomeserverTestCase):
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
+ self.assertEqual(channel.code, 404)
+
+ @unittest.override_config(
+ {
+ "public_baseurl": "https://tesths",
+ "default_identity_server": "https://testis",
+ "extra_well_known_client_content": {"custom": False},
+ }
+ )
+ def test_client_well_known_custom(self) -> None:
+ channel = self.make_request(
+ "GET", "/.well-known/matrix/client", shorthand=False
+ )
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "m.homeserver": {"base_url": "https://tesths/"},
+ "m.identity_server": {"base_url": "https://testis"},
+ "custom": False,
+ },
+ )
@unittest.override_config({"serve_server_wellknown": True})
def test_server_well_known(self) -> None:
@@ -65,7 +85,7 @@ class WellKnownTests(unittest.HomeserverTestCase):
"GET", "/.well-known/matrix/server", shorthand=False
)
- self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(channel.code, 200)
self.assertEqual(
channel.json_body,
{"m.server": "test:443"},
@@ -75,4 +95,4 @@ class WellKnownTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", "/.well-known/matrix/server", shorthand=False
)
- self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
+ self.assertEqual(channel.code, 404)
diff --git a/tests/server.py b/tests/server.py
index b9f465971f..c447d5e4c4 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -25,6 +25,7 @@ from typing import (
Callable,
Dict,
Iterable,
+ List,
MutableMapping,
Optional,
Tuple,
@@ -43,6 +44,7 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
IAddress,
+ IConsumer,
IHostnameResolver,
IProtocol,
IPullProducer,
@@ -53,16 +55,16 @@ from twisted.internet.interfaces import (
ITransport,
)
from twisted.python.failure import Failure
-from twisted.test.proto_helpers import (
- AccumulatingProtocol,
- MemoryReactor,
- MemoryReactorClock,
-)
+from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http_headers import Headers
from twisted.web.resource import IResource
from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig
+from synapse.events.presence_router import load_legacy_presence_router
+from synapse.events.spamcheck import load_legacy_spam_checkers
+from synapse.events.third_party_rules import load_legacy_third_party_event_rules
+from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.http.site import SynapseRequest
from synapse.logging.context import ContextResourceUsage
from synapse.server import HomeServer
@@ -96,6 +98,7 @@ class TimedOutException(Exception):
"""
+@implementer(IConsumer)
@attr.s(auto_attribs=True)
class FakeChannel:
"""
@@ -104,7 +107,7 @@ class FakeChannel:
"""
site: Union[Site, "FakeSite"]
- _reactor: MemoryReactor
+ _reactor: MemoryReactorClock
result: dict = attr.Factory(dict)
_ip: str = "127.0.0.1"
_producer: Optional[Union[IPullProducer, IPushProducer]] = None
@@ -122,8 +125,16 @@ class FakeChannel:
self._request = request
@property
- def json_body(self):
- return json.loads(self.text_body)
+ def json_body(self) -> JsonDict:
+ body = json.loads(self.text_body)
+ assert isinstance(body, dict)
+ return body
+
+ @property
+ def json_list(self) -> List[JsonDict]:
+ body = json.loads(self.text_body)
+ assert isinstance(body, list)
+ return body
@property
def text_body(self) -> str:
@@ -140,7 +151,7 @@ class FakeChannel:
return self.result.get("done", False)
@property
- def code(self):
+ def code(self) -> int:
if not self.result:
raise Exception("No result yet.")
return int(self.result["code"])
@@ -160,7 +171,7 @@ class FakeChannel:
self.result["reason"] = reason
self.result["headers"] = headers
- def write(self, content):
+ def write(self, content: bytes) -> None:
assert isinstance(content, bytes), "Should be bytes! " + repr(content)
if "body" not in self.result:
@@ -168,11 +179,16 @@ class FakeChannel:
self.result["body"] += content
- def registerProducer(self, producer, streaming):
+ # Type ignore: mypy doesn't like the fact that producer isn't an IProducer.
+ def registerProducer( # type: ignore[override]
+ self,
+ producer: Union[IPullProducer, IPushProducer],
+ streaming: bool,
+ ) -> None:
self._producer = producer
self.producerStreaming = streaming
- def _produce():
+ def _produce() -> None:
if self._producer:
self._producer.resumeProducing()
self._reactor.callLater(0.1, _produce)
@@ -180,31 +196,32 @@ class FakeChannel:
if not streaming:
self._reactor.callLater(0.0, _produce)
- def unregisterProducer(self):
+ def unregisterProducer(self) -> None:
if self._producer is None:
return
self._producer = None
- def requestDone(self, _self):
+ def requestDone(self, _self: Request) -> None:
self.result["done"] = True
if isinstance(_self, SynapseRequest):
+ assert _self.logcontext is not None
self.resource_usage = _self.logcontext.get_resource_usage()
- def getPeer(self):
+ def getPeer(self) -> IAddress:
# We give an address so that getClientAddress/getClientIP returns a non null entry,
# causing us to record the MAU
return address.IPv4Address("TCP", self._ip, 3423)
- def getHost(self):
+ def getHost(self) -> IAddress:
# this is called by Request.__init__ to configure Request.host.
return address.IPv4Address("TCP", "127.0.0.1", 8888)
- def isSecure(self):
+ def isSecure(self) -> bool:
return False
@property
- def transport(self):
+ def transport(self) -> "FakeChannel":
return self
def await_result(self, timeout_ms: int = 1000) -> None:
@@ -830,7 +847,6 @@ def setup_test_homeserver(
# Mock TLS
hs.tls_server_context_factory = Mock()
- hs.tls_client_options_factory = Mock()
hs.setup()
if homeserver_to_use == TestHomeServer:
@@ -901,4 +917,14 @@ def setup_test_homeserver(
# Make the threadpool and database transactions synchronous for testing.
_make_test_homeserver_synchronous(hs)
+ # Load any configured modules into the homeserver
+ module_api = hs.get_module_api()
+ for module, config in hs.config.modules.loaded_modules:
+ module(config=config, api=module_api)
+
+ load_legacy_spam_checkers(hs)
+ load_legacy_third_party_event_rules(hs)
+ load_legacy_presence_router(hs)
+ load_legacy_password_auth_providers(hs)
+
return hs
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 07e29788e5..bf403045e9 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -11,16 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType
from synapse.api.errors import ResourceLimitError
from synapse.rest import admin
from synapse.rest.client import login, room, sync
+from synapse.server import HomeServer
from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices,
)
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -52,7 +55,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
return config
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.server_notices_sender = self.hs.get_server_notices_sender()
# relying on [1] is far from ideal, but the only case where
@@ -96,7 +99,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
"""Test when user has blocked notice, but should have it removed"""
- self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
+ self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ return_value=make_awaitable(None)
+ )
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
@@ -112,7 +117,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
Test when user has blocked notice, but notice ought to be there (NOOP)
"""
- self._rlsn._auth.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock(
return_value=make_awaitable(None),
side_effect=ResourceLimitError(403, "foo"),
)
@@ -132,7 +137,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
Test when user does not have blocked notice, but should have one
"""
- self._rlsn._auth.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock(
return_value=make_awaitable(None),
side_effect=ResourceLimitError(403, "foo"),
)
@@ -145,7 +150,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
Test when user does not have blocked notice, nor should they (NOOP)
"""
- self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
+ self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ return_value=make_awaitable(None)
+ )
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -156,7 +163,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user is not part of the MAU cohort - this should not ever
happen - but ...
"""
- self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
+ self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ return_value=make_awaitable(None)
+ )
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=make_awaitable(None)
)
@@ -170,7 +179,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test that when server is over MAU limit and alerting is suppressed, then
an alert message is not sent into the room
"""
- self._rlsn._auth.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock(
return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
@@ -185,7 +194,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
Test that when a server is disabled, that MAU limit alerting is ignored.
"""
- self._rlsn._auth.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock(
return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
@@ -202,7 +211,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
When the room is already in a blocked state, test that when alerting
is suppressed that the room is returned to an unblocked state.
"""
- self._rlsn._auth.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock(
return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
@@ -245,7 +254,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
c["admin_contact"] = "mailto:user@test.com"
return c
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
self.server_notices_sender = self.hs.get_server_notices_sender()
self.server_notices_manager = self.hs.get_server_notices_manager()
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 8370a27195..2e3f2318d9 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -13,7 +13,17 @@
# limitations under the License.
import itertools
-from typing import List
+from typing import (
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+)
import attr
@@ -22,13 +32,13 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions
from synapse.event_auth import auth_types_for_event
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
from synapse.state.v2 import (
_get_auth_chain_difference,
lexicographical_topological_sort,
resolve_events_with_store,
)
-from synapse.types import EventID
+from synapse.types import EventID, StateMap
from tests import unittest
@@ -48,7 +58,7 @@ ORIGIN_SERVER_TS = 0
class FakeClock:
- def sleep(self, msec):
+ def sleep(self, msec: float) -> "defer.Deferred[None]":
return defer.succeed(None)
@@ -60,7 +70,14 @@ class FakeEvent:
as domain.
"""
- def __init__(self, id, sender, type, state_key, content):
+ def __init__(
+ self,
+ id: str,
+ sender: str,
+ type: str,
+ state_key: Optional[str],
+ content: Mapping[str, object],
+ ):
self.node_id = id
self.event_id = EventID(id, "example.com").to_string()
self.sender = sender
@@ -69,12 +86,12 @@ class FakeEvent:
self.content = content
self.room_id = ROOM_ID
- def to_event(self, auth_events, prev_events):
+ def to_event(self, auth_events: List[str], prev_events: List[str]) -> EventBase:
"""Given the auth_events and prev_events, convert to a Frozen Event
Args:
- auth_events (list[str]): list of event_ids
- prev_events (list[str]): list of event_ids
+ auth_events: list of event_ids
+ prev_events: list of event_ids
Returns:
FrozenEvent
@@ -164,7 +181,7 @@ INITIAL_EDGES = ["START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE"]
class StateTestCase(unittest.TestCase):
- def test_ban_vs_pl(self):
+ def test_ban_vs_pl(self) -> None:
events = [
FakeEvent(
id="PA",
@@ -202,7 +219,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
- def test_join_rule_evasion(self):
+ def test_join_rule_evasion(self) -> None:
events = [
FakeEvent(
id="JR",
@@ -226,7 +243,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
- def test_offtopic_pl(self):
+ def test_offtopic_pl(self) -> None:
events = [
FakeEvent(
id="PA",
@@ -257,7 +274,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
- def test_topic_basic(self):
+ def test_topic_basic(self) -> None:
events = [
FakeEvent(
id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
@@ -297,7 +314,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
- def test_topic_reset(self):
+ def test_topic_reset(self) -> None:
events = [
FakeEvent(
id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
@@ -327,7 +344,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
- def test_topic(self):
+ def test_topic(self) -> None:
events = [
FakeEvent(
id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
@@ -380,7 +397,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
- def test_mainline_sort(self):
+ def test_mainline_sort(self) -> None:
"""Tests that the mainline ordering works correctly."""
events = [
@@ -434,22 +451,26 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
- def do_check(self, events, edges, expected_state_ids):
+ def do_check(
+ self,
+ events: List[FakeEvent],
+ edges: List[List[str]],
+ expected_state_ids: List[str],
+ ) -> None:
"""Take a list of events and edges and calculate the state of the
graph at END, and asserts it matches `expected_state_ids`
Args:
- events (list[FakeEvent])
- edges (list[list[str]]): A list of chains of event edges, e.g.
+ events
+ edges: A list of chains of event edges, e.g.
`[[A, B, C]]` are edges A->B and B->C.
- expected_state_ids (list[str]): The expected state at END, (excluding
+ expected_state_ids: The expected state at END, (excluding
the keys that haven't changed since START).
"""
# We want to sort the events into topological order for processing.
- graph = {}
+ graph: Dict[str, Set[str]] = {}
- # node_id -> FakeEvent
- fake_event_map = {}
+ fake_event_map: Dict[str, FakeEvent] = {}
for ev in itertools.chain(INITIAL_EVENTS, events):
graph[ev.node_id] = set()
@@ -462,10 +483,8 @@ class StateTestCase(unittest.TestCase):
for a, b in pairwise(edge_list):
graph[a].add(b)
- # event_id -> FrozenEvent
- event_map = {}
- # node_id -> state
- state_at_event = {}
+ event_map: Dict[str, EventBase] = {}
+ state_at_event: Dict[str, StateMap[str]] = {}
# We copy the map as the sort consumes the graph
graph_copy = {k: set(v) for k, v in graph.items()}
@@ -476,6 +495,7 @@ class StateTestCase(unittest.TestCase):
prev_events = list(graph[node_id])
+ state_before: StateMap[str]
if len(prev_events) == 0:
state_before = {}
elif len(prev_events) == 1:
@@ -496,7 +516,16 @@ class StateTestCase(unittest.TestCase):
if fake_event.state_key is not None:
state_after[(fake_event.type, fake_event.state_key)] = event_id
- auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event))
+ # This type ignore is a bit sad. Things we have tried:
+ # 1. Define a `GenericEvent` Protocol satisfied by FakeEvent, EventBase and
+ # EventBuilder. But this is Hard because the relevant attributes are
+ # DictProperty[T] descriptors on EventBase but normal Ts on FakeEvent.
+ # 2. Define a `GenericEvent` Protocol describing `FakeEvent` only, and
+ # change this function to accept Union[Event, EventBase, EventBuilder].
+ # This seems reasonable to me, but mypy isn't happy. I think that's
+ # a mypy bug, see https://github.com/python/mypy/issues/5570
+ # Instead, resort to a type-ignore.
+ auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event)) # type: ignore[arg-type]
auth_events = []
for key in auth_types:
@@ -530,8 +559,14 @@ class StateTestCase(unittest.TestCase):
class LexicographicalTestCase(unittest.TestCase):
- def test_simple(self):
- graph = {"l": {"o"}, "m": {"n", "o"}, "n": {"o"}, "o": set(), "p": {"o"}}
+ def test_simple(self) -> None:
+ graph: Dict[str, Set[str]] = {
+ "l": {"o"},
+ "m": {"n", "o"},
+ "n": {"o"},
+ "o": set(),
+ "p": {"o"},
+ }
res = list(lexicographical_topological_sort(graph, key=lambda x: x))
@@ -539,7 +574,7 @@ class LexicographicalTestCase(unittest.TestCase):
class SimpleParamStateTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
# We build up a simple DAG.
event_map = {}
@@ -627,7 +662,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
]
}
- def test_event_map_none(self):
+ def test_event_map_none(self) -> None:
# Test that we correctly handle passing `None` as the event_map
state_d = resolve_events_with_store(
@@ -649,7 +684,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
events.
"""
- def test_simple(self):
+ def test_simple(self) -> None:
# Test getting the auth difference for a simple chain with a single
# unpersisted event:
#
@@ -695,7 +730,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
self.assertEqual(difference, {c.event_id})
- def test_multiple_unpersisted_chain(self):
+ def test_multiple_unpersisted_chain(self) -> None:
# Test getting the auth difference for a simple chain with multiple
# unpersisted events:
#
@@ -752,7 +787,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
self.assertEqual(difference, {d.event_id, c.event_id})
- def test_unpersisted_events_different_sets(self):
+ def test_unpersisted_events_different_sets(self) -> None:
# Test getting the auth difference for with multiple unpersisted events
# in different branches:
#
@@ -820,7 +855,10 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
self.assertEqual(difference, {d.event_id, e.event_id})
-def pairwise(iterable):
+T = TypeVar("T")
+
+
+def pairwise(iterable: Iterable[T]) -> Iterable[Tuple[T, T]]:
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = itertools.tee(iterable)
next(b, None)
@@ -829,24 +867,26 @@ def pairwise(iterable):
@attr.s
class TestStateResolutionStore:
- event_map = attr.ib()
+ event_map: Dict[str, EventBase] = attr.ib()
- def get_events(self, event_ids, allow_rejected=False):
+ def get_events(
+ self, event_ids: Collection[str], allow_rejected: bool = False
+ ) -> "defer.Deferred[Dict[str, EventBase]]":
"""Get events from the database
Args:
- event_ids (list): The event_ids of the events to fetch
- allow_rejected (bool): If True return rejected events.
+ event_ids: The event_ids of the events to fetch
+ allow_rejected: If True return rejected events.
Returns:
- Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
+ Dict from event_id to event.
"""
return defer.succeed(
{eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
)
- def _get_auth_chain(self, event_ids: List[str]) -> List[str]:
+ def _get_auth_chain(self, event_ids: Iterable[str]) -> List[str]:
"""Gets the full auth chain for a set of events (including rejected
events).
@@ -880,7 +920,9 @@ class TestStateResolutionStore:
return list(result)
- def get_auth_chain_difference(self, room_id, auth_sets):
+ def get_auth_chain_difference(
+ self, room_id: str, auth_sets: List[Set[str]]
+ ) -> "defer.Deferred[Set[str]]":
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
common = set(chains[0]).intersection(*chains[1:])
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 38963ce4a7..46d829b062 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -143,7 +143,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
self.event_id = res["event_id"]
# Reset the event cache so the tests start with it empty
- self.store._get_event_cache.clear()
+ self.get_success(self.store._get_event_cache.clear())
def test_simple(self):
"""Test that we cache events that we pull from the DB."""
@@ -160,7 +160,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
"""
# Reset the event cache
- self.store._get_event_cache.clear()
+ self.get_success(self.store._get_event_cache.clear())
with LoggingContext("test") as ctx:
# We keep hold of the event event though we never use it.
@@ -170,7 +170,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
# Reset the event cache
- self.store._get_event_cache.clear()
+ self.get_success(self.store._get_event_cache.clear())
with LoggingContext("test") as ctx:
self.get_success(self.store.get_event(self.event_id))
@@ -345,7 +345,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
self.event_id = res["event_id"]
# Reset the event cache so the tests start with it empty
- self.store._get_event_cache.clear()
+ self.get_success(self.store._get_event_cache.clear())
@contextmanager
def blocking_get_event_calls(
diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py
index 9abd0cb446..1edb619630 100644
--- a/tests/storage/databases/main/test_room.py
+++ b/tests/storage/databases/main/test_room.py
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
+
+from synapse.api.constants import RoomTypes
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.storage.databases.main.room import _BackgroundUpdates
@@ -91,3 +94,69 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
)
)
self.assertEqual(room_creator_after, self.user_id)
+
+ def test_background_add_room_type_column(self):
+ """Test that the background update to populate the `room_type` column in
+ `room_stats_state` works properly.
+ """
+
+ # Create a room without a type
+ room_id = self._generate_room()
+
+ # Get event_id of the m.room.create event
+ event_id = self.get_success(
+ self.store.db_pool.simple_select_one_onecol(
+ table="current_state_events",
+ keyvalues={
+ "room_id": room_id,
+ "type": "m.room.create",
+ },
+ retcol="event_id",
+ )
+ )
+
+ # Fake a room creation event with a room type
+ event = {
+ "content": {
+ "creator": "@user:server.org",
+ "room_version": "9",
+ "type": RoomTypes.SPACE,
+ },
+ "type": "m.room.create",
+ }
+ self.get_success(
+ self.store.db_pool.simple_update(
+ table="event_json",
+ keyvalues={"event_id": event_id},
+ updatevalues={"json": json.dumps(event)},
+ desc="test",
+ )
+ )
+
+ # Insert and run the background update
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN,
+ "progress_json": "{}",
+ },
+ )
+ )
+
+ # ... and tell the DataStore that it hasn't finished all updates yet
+ self.store.db_pool.updates._all_done = False
+
+ # Now let's actually drive the updates to completion
+ self.wait_for_background_updates()
+
+ # Make sure the background update filled in the room type
+ room_type_after = self.get_success(
+ self.store.db_pool.simple_select_one_onecol(
+ table="room_stats_state",
+ keyvalues={"room_id": room_id},
+ retcol="room_type",
+ allow_none=True,
+ )
+ )
+ self.assertEqual(room_type_after, RoomTypes.SPACE)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 0f9add4841..fc43d7edd1 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -12,143 +12,175 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.storage.databases.main.event_push_actions import NotifCounts
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
USER_ID = "@user:example.com"
-PlAIN_NOTIF = ["notify", {"set_tweak": "highlight", "value": False}]
-HIGHLIGHT = [
- "notify",
- {"set_tweak": "sound", "value": "default"},
- {"set_tweak": "highlight"},
-]
-
class EventPushActionsStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- self.persist_events_store = hs.get_datastores().persist_events
+ persist_events_store = hs.get_datastores().persist_events
+ assert persist_events_store is not None
+ self.persist_events_store = persist_events_store
- def test_get_unread_push_actions_for_user_in_range_for_http(self):
+ def test_get_unread_push_actions_for_user_in_range_for_http(self) -> None:
self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_http(
USER_ID, 0, 1000, 20
)
)
- def test_get_unread_push_actions_for_user_in_range_for_email(self):
+ def test_get_unread_push_actions_for_user_in_range_for_email(self) -> None:
self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_email(
USER_ID, 0, 1000, 20
)
)
- def test_count_aggregation(self):
- room_id = "!foo:example.com"
- user_id = "@user1235:example.com"
+ def test_count_aggregation(self) -> None:
+ # Create a user to receive notifications and send receipts.
+ user_id = self.register_user("user1235", "pass")
+ token = self.login("user1235", "pass")
+
+ # And another users to send events.
+ other_id = self.register_user("other", "pass")
+ other_token = self.login("other", "pass")
+
+ # Create a room and put both users in it.
+ room_id = self.helper.create_room_as(user_id, tok=token)
+ self.helper.join(room_id, other_id, tok=other_token)
- def _assert_counts(noitf_count, highlight_count):
+ last_event_id: str
+
+ def _assert_counts(noitf_count: int, highlight_count: int) -> None:
counts = self.get_success(
self.store.db_pool.runInteraction(
- "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
+ "get-unread-counts",
+ self.store._get_unread_counts_by_receipt_txn,
+ room_id,
+ user_id,
)
)
self.assertEqual(
counts,
NotifCounts(
notify_count=noitf_count,
- unread_count=0, # Unread counts are tested in the sync tests.
+ unread_count=0,
highlight_count=highlight_count,
),
)
- def _inject_actions(stream, action):
- event = Mock()
- event.room_id = room_id
- event.event_id = "$test:example.com"
- event.internal_metadata.stream_ordering = stream
- event.internal_metadata.is_outlier.return_value = False
- event.depth = stream
-
- self.get_success(
- self.store.add_push_actions_to_staging(
- event.event_id,
- {user_id: action},
- False,
- )
- )
- self.get_success(
- self.store.db_pool.runInteraction(
- "",
- self.persist_events_store._set_push_actions_for_event_and_users_txn,
- [(event, None)],
- [(event, None)],
- )
+ def _create_event(highlight: bool = False) -> str:
+ result = self.helper.send_event(
+ room_id,
+ type="m.room.message",
+ content={"msgtype": "m.text", "body": user_id if highlight else "msg"},
+ tok=other_token,
)
+ nonlocal last_event_id
+ last_event_id = result["event_id"]
+ return last_event_id
- def _rotate(stream):
- self.get_success(
- self.store.db_pool.runInteraction(
- "", self.store._rotate_notifs_before_txn, stream
- )
- )
+ def _rotate() -> None:
+ self.get_success(self.store._rotate_notifs())
- def _mark_read(stream, depth):
+ def _mark_read(event_id: str) -> None:
self.get_success(
- self.store.db_pool.runInteraction(
- "",
- self.store._remove_old_push_actions_before_txn,
+ self.store.insert_receipt(
room_id,
- user_id,
- stream,
+ "m.read",
+ user_id=user_id,
+ event_ids=[event_id],
+ data={},
)
)
_assert_counts(0, 0)
- _inject_actions(1, PlAIN_NOTIF)
+ _create_event()
_assert_counts(1, 0)
- _rotate(2)
+ _rotate()
_assert_counts(1, 0)
- _inject_actions(3, PlAIN_NOTIF)
+ event_id = _create_event()
_assert_counts(2, 0)
- _rotate(4)
+ _rotate()
_assert_counts(2, 0)
- _inject_actions(5, PlAIN_NOTIF)
- _mark_read(3, 3)
+ _create_event()
+ _mark_read(event_id)
_assert_counts(1, 0)
- _mark_read(5, 5)
+ _mark_read(last_event_id)
_assert_counts(0, 0)
- _inject_actions(6, PlAIN_NOTIF)
- _rotate(7)
+ _create_event()
+ _rotate()
+ _assert_counts(1, 0)
- self.get_success(
- self.store.db_pool.simple_delete(
- table="event_push_actions", keyvalues={"1": 1}, desc=""
+ # Delete old event push actions, this should not affect the (summarised) count.
+ #
+ # All event push actions are kept for 24 hours, so need to move forward
+ # in time.
+ self.pump(60 * 60 * 24)
+ self.get_success(self.store._remove_old_push_actions_that_have_rotated())
+ # Double check that the event push actions have been cleared (i.e. that
+ # any results *must* come from the summary).
+ result = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="event_push_actions",
+ keyvalues={"1": 1},
+ retcols=("event_id",),
+ desc="",
)
)
-
+ self.assertEqual(result, [])
_assert_counts(1, 0)
- _mark_read(7, 7)
+ _mark_read(last_event_id)
_assert_counts(0, 0)
- _inject_actions(8, HIGHLIGHT)
+ event_id = _create_event(True)
_assert_counts(1, 1)
- _rotate(9)
+ _rotate()
_assert_counts(1, 1)
- _rotate(10)
+
+ # Check that adding another notification and rotating after highlight
+ # works.
+ _create_event()
+ _rotate()
+ _assert_counts(2, 1)
+
+ # Check that sending read receipts at different points results in the
+ # right counts.
+ _mark_read(event_id)
+ _assert_counts(1, 0)
+ _mark_read(last_event_id)
+ _assert_counts(0, 0)
+
+ _create_event(True)
_assert_counts(1, 1)
+ _mark_read(last_event_id)
+ _assert_counts(0, 0)
+ _rotate()
+ _assert_counts(0, 0)
- def test_find_first_stream_ordering_after_ts(self):
- def add_event(so, ts):
+ def test_find_first_stream_ordering_after_ts(self) -> None:
+ def add_event(so: int, ts: int) -> None:
self.get_success(
self.store.db_pool.simple_insert(
"events",
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 2ff88e64a5..3ce4f35cb7 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -70,7 +70,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
def persist_event(self, event, state=None):
"""Persist the event, with optional state"""
context = self.get_success(
- self.state.compute_event_context(event, state_ids_before_event=state)
+ self.state.compute_event_context(
+ event,
+ state_ids_before_event=state,
+ partial_state=None if state is None else False,
+ )
)
self.get_success(self._persistence.persist_event(event, context))
@@ -148,6 +152,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
self.state.compute_event_context(
remote_event_2,
state_ids_before_event=state_before_gap,
+ partial_state=False,
)
)
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 8dfaa0559b..9c1182ed16 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -115,6 +115,6 @@ class PurgeTests(HomeserverTestCase):
)
# The events aren't found.
- self.store._invalidate_get_event_cache(create_event.event_id)
+ self.store._invalidate_local_get_event_cache(create_event.event_id)
self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/storage/test_receipts.py
index 19f57115a1..c89bfff241 100644
--- a/tests/replication/slave/storage/test_receipts.py
+++ b/tests/storage/test_receipts.py
@@ -12,24 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
from synapse.api.constants import ReceiptTypes
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.types import UserID, create_requester
from tests.test_utils.event_injection import create_event
-
-from ._base import BaseSlavedStoreTestCase
+from tests.unittest import HomeserverTestCase
OTHER_USER_ID = "@other:test"
OUR_USER_ID = "@our:test"
-class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
+class ReceiptTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, homeserver) -> None:
+ super().prepare(reactor, clock, homeserver)
- STORE_TYPE = SlavedReceiptsStore
+ self.store = homeserver.get_datastores().main
- def prepare(self, reactor, clock, homeserver):
- super().prepare(reactor, clock, homeserver)
self.room_creator = homeserver.get_room_creation_handler()
self.persist_event_storage_controller = (
self.hs.get_storage_controllers().persistence
@@ -85,32 +84,42 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
)
)
- def test_return_empty_with_no_data(self):
+ def test_return_empty_with_no_data(self) -> None:
res = self.get_success(
- self.master_store.get_receipts_for_user(
- OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
+ self.store.get_receipts_for_user(
+ OUR_USER_ID,
+ [
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ],
)
)
self.assertEqual(res, {})
res = self.get_success(
- self.master_store.get_receipts_for_user_with_orderings(
+ self.store.get_receipts_for_user_with_orderings(
OUR_USER_ID,
- [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
+ [
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ],
)
)
self.assertEqual(res, {})
res = self.get_success(
- self.master_store.get_last_receipt_event_id_for_user(
+ self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID,
self.room_id1,
- [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
+ [
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ],
)
)
self.assertEqual(res, None)
- def test_get_receipts_for_user(self):
+ def test_get_receipts_for_user(self) -> None:
# Send some events into the first room
event1_1_id = self.create_and_send_event(
self.room_id1, UserID.from_string(OTHER_USER_ID)
@@ -121,20 +130,20 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
# Send public read receipt for the first event
self.get_success(
- self.master_store.insert_receipt(
+ self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {}
)
)
# Send private read receipt for the second event
self.get_success(
- self.master_store.insert_receipt(
+ self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
)
)
# Test we get the latest event when we want both private and public receipts
res = self.get_success(
- self.master_store.get_receipts_for_user(
+ self.store.get_receipts_for_user(
OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
)
)
@@ -142,26 +151,24 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
# Test we get the older event when we want only public receipt
res = self.get_success(
- self.master_store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ])
+ self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ])
)
self.assertEqual(res, {self.room_id1: event1_1_id})
# Test we get the latest event when we want only the public receipt
res = self.get_success(
- self.master_store.get_receipts_for_user(
- OUR_USER_ID, [ReceiptTypes.READ_PRIVATE]
- )
+ self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ_PRIVATE])
)
self.assertEqual(res, {self.room_id1: event1_2_id})
# Test receipt updating
self.get_success(
- self.master_store.insert_receipt(
+ self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {}
)
)
res = self.get_success(
- self.master_store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ])
+ self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ])
)
self.assertEqual(res, {self.room_id1: event1_2_id})
@@ -172,18 +179,18 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
# Test new room is reflected in what the method returns
self.get_success(
- self.master_store.insert_receipt(
+ self.store.insert_receipt(
self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
)
)
res = self.get_success(
- self.master_store.get_receipts_for_user(
+ self.store.get_receipts_for_user(
OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
)
)
self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id})
- def test_get_last_receipt_event_id_for_user(self):
+ def test_get_last_receipt_event_id_for_user(self) -> None:
# Send some events into the first room
event1_1_id = self.create_and_send_event(
self.room_id1, UserID.from_string(OTHER_USER_ID)
@@ -194,20 +201,20 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
# Send public read receipt for the first event
self.get_success(
- self.master_store.insert_receipt(
+ self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {}
)
)
# Send private read receipt for the second event
self.get_success(
- self.master_store.insert_receipt(
+ self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
)
)
# Test we get the latest event when we want both private and public receipts
res = self.get_success(
- self.master_store.get_last_receipt_event_id_for_user(
+ self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID,
self.room_id1,
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
@@ -217,7 +224,7 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
# Test we get the older event when we want only public receipt
res = self.get_success(
- self.master_store.get_last_receipt_event_id_for_user(
+ self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
)
)
@@ -225,7 +232,7 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
# Test we get the latest event when we want only the private receipt
res = self.get_success(
- self.master_store.get_last_receipt_event_id_for_user(
+ self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE]
)
)
@@ -233,12 +240,12 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
# Test receipt updating
self.get_success(
- self.master_store.insert_receipt(
+ self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {}
)
)
res = self.get_success(
- self.master_store.get_last_receipt_event_id_for_user(
+ self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
)
)
@@ -251,12 +258,12 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
# Test new room is reflected in what the method returns
self.get_success(
- self.master_store.insert_receipt(
+ self.store.insert_receipt(
self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
)
)
res = self.get_success(
- self.master_store.get_last_receipt_event_id_for_user(
+ self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID,
self.room_id2,
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 3c79dabc9f..3405efb6a8 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomAlias, RoomID, UserID
@@ -65,71 +64,3 @@ class RoomStoreTestCase(HomeserverTestCase):
self.assertIsNone(
(self.get_success(self.store.get_room_with_stats("!uknown:test"))),
)
-
-
-class RoomEventsStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
- # Room events need the full datastore, for persist_event() and
- # get_room_state()
- self.store = hs.get_datastores().main
- self._storage_controllers = hs.get_storage_controllers()
- self.event_factory = hs.get_event_factory()
-
- self.room = RoomID.from_string("!abcde:test")
-
- self.get_success(
- self.store.store_room(
- self.room.to_string(),
- room_creator_user_id="@creator:text",
- is_public=True,
- room_version=RoomVersions.V1,
- )
- )
-
- def inject_room_event(self, **kwargs):
- self.get_success(
- self._storage_controllers.persistence.persist_event(
- self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
- )
- )
-
- def STALE_test_room_name(self):
- name = "A-Room-Name"
-
- self.inject_room_event(
- etype=EventTypes.Name, name=name, content={"name": name}, depth=1
- )
-
- state = self.get_success(
- self._storage_controllers.state.get_current_state(
- room_id=self.room.to_string()
- )
- )
-
- self.assertEqual(1, len(state))
- self.assertObjectHasAttributes(
- {"type": "m.room.name", "room_id": self.room.to_string(), "name": name},
- state[0],
- )
-
- def STALE_test_room_topic(self):
- topic = "A place for things"
-
- self.inject_room_event(
- etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
- )
-
- state = self.get_success(
- self._storage_controllers.state.get_current_state(
- room_id=self.room.to_string()
- )
- )
-
- self.assertEqual(1, len(state))
- self.assertObjectHasAttributes(
- {"type": "m.room.topic", "room_id": self.room.to_string(), "topic": topic},
- state[0],
- )
-
- # Not testing the various 'level' methods for now because there's lots
- # of them and need coalescing; see JIRA SPEC-11
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 1218786d79..8794401823 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -110,60 +110,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# It now knows about Charlie's server.
self.assertEqual(self.store._known_servers_count, 2)
- def test_get_joined_users_from_context(self) -> None:
- room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
- bob_event = self.get_success(
- event_injection.inject_member_event(
- self.hs, room, self.u_bob, Membership.JOIN
- )
- )
-
- # first, create a regular event
- event, context = self.get_success(
- event_injection.create_event(
- self.hs,
- room_id=room,
- sender=self.u_alice,
- prev_event_ids=[bob_event.event_id],
- type="m.test.1",
- content={},
- )
- )
-
- users = self.get_success(
- self.store.get_joined_users_from_context(event, context)
- )
- self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
-
- # Regression test for #7376: create a state event whose key matches bob's
- # user_id, but which is *not* a membership event, and persist that; then check
- # that `get_joined_users_from_context` returns the correct users for the next event.
- non_member_event = self.get_success(
- event_injection.inject_event(
- self.hs,
- room_id=room,
- sender=self.u_bob,
- prev_event_ids=[bob_event.event_id],
- type="m.test.2",
- state_key=self.u_bob,
- content={},
- )
- )
- event, context = self.get_success(
- event_injection.create_event(
- self.hs,
- room_id=room,
- sender=self.u_alice,
- prev_event_ids=[non_member_event.event_id],
- type="m.test.3",
- content={},
- )
- )
- users = self.get_success(
- self.store.get_joined_users_from_context(event, context)
- )
- self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
-
def test__null_byte_in_display_name_properly_handled(self) -> None:
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
@@ -212,6 +158,75 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# Check that alice's display name is now None
self.assertEqual(row[0]["display_name"], None)
+ def test_room_is_locally_forgotten(self) -> None:
+ """Test that when the last local user has forgotten a room it is known as forgotten."""
+ # join two local and one remote user
+ self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
+ self.get_success(
+ event_injection.inject_member_event(self.hs, self.room, self.u_bob, "join")
+ )
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, self.room, self.u_charlie.to_string(), "join"
+ )
+ )
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ # local users leave the room and the room is not forgotten
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, self.room, self.u_alice, "leave"
+ )
+ )
+ self.get_success(
+ event_injection.inject_member_event(self.hs, self.room, self.u_bob, "leave")
+ )
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ # first user forgets the room, room is not forgotten
+ self.get_success(self.store.forget(self.u_alice, self.room))
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ # second (last local) user forgets the room and the room is forgotten
+ self.get_success(self.store.forget(self.u_bob, self.room))
+ self.assertTrue(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ def test_join_locally_forgotten_room(self) -> None:
+ """Tests if a user joins a forgotten room the room is not forgotten anymore."""
+ self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ # after leaving and forget the room, it is forgotten
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, self.room, self.u_alice, "leave"
+ )
+ )
+ self.get_success(self.store.forget(self.u_alice, self.room))
+ self.assertTrue(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ # after rejoin the room is not forgotten anymore
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, self.room, self.u_alice, "join"
+ )
+ )
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 8043bdbde2..5564161750 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -369,8 +369,8 @@ class StateStoreTestCase(HomeserverTestCase):
state_dict_ids = cache_entry.value
self.assertEqual(cache_entry.full, False)
- self.assertEqual(cache_entry.known_absent, {(e1.type, e1.state_key)})
- self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
+ self.assertEqual(cache_entry.known_absent, set())
+ self.assertDictEqual(state_dict_ids, {})
############################################
# test that things work with a partial cache
@@ -387,7 +387,7 @@ class StateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(is_all, False)
- self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
+ self.assertDictEqual({}, state_dict)
room_id = self.room.to_string()
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
@@ -412,7 +412,7 @@ class StateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(is_all, False)
- self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
+ self.assertDictEqual({}, state_dict)
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
@@ -443,7 +443,7 @@ class StateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(is_all, False)
- self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
+ self.assertDictEqual({}, state_dict)
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index e2c506e5a4..e42d7b9ba0 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -1,4 +1,4 @@
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,15 +13,50 @@
# limitations under the License.
import unittest
-from typing import Optional
+from typing import Collection, Dict, Iterable, List, Optional
+
+from parameterized import parameterized
from synapse import event_auth
from synapse.api.constants import EventContentFields
-from synapse.api.errors import AuthError
-from synapse.api.room_versions import RoomVersions
+from synapse.api.errors import AuthError, SynapseError
+from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
from synapse.events import EventBase, make_event_from_dict
+from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import JsonDict, get_domain_from_id
+from tests.test_utils import get_awaitable_result
+
+
+class _StubEventSourceStore:
+ """A stub implementation of the EventSourceStore"""
+
+ def __init__(self):
+ self._store: Dict[str, EventBase] = {}
+
+ def add_event(self, event: EventBase):
+ self._store[event.event_id] = event
+
+ def add_events(self, events: Iterable[EventBase]):
+ for event in events:
+ self._store[event.event_id] = event
+
+ async def get_events(
+ self,
+ event_ids: Collection[str],
+ redact_behaviour: EventRedactBehaviour,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ ) -> Dict[str, EventBase]:
+ assert allow_rejected
+ assert not get_prev_content
+ assert redact_behaviour == EventRedactBehaviour.as_is
+ results = {}
+ for e in event_ids:
+ if e in self._store:
+ results[e] = self._store[e]
+ return results
+
class EventAuthTestCase(unittest.TestCase):
def test_rejected_auth_events(self):
@@ -30,40 +65,176 @@ class EventAuthTestCase(unittest.TestCase):
"""
creator = "@creator:example.com"
auth_events = [
- _create_event(creator),
- _join_event(creator),
+ _create_event(RoomVersions.V9, creator),
+ _join_event(RoomVersions.V9, creator),
]
+ event_store = _StubEventSourceStore()
+ event_store.add_events(auth_events)
+
# creator should be able to send state
- event_auth.check_auth_rules_for_event(
- RoomVersions.V9,
- _random_state_event(creator),
- auth_events,
+ event = _random_state_event(RoomVersions.V9, creator, auth_events)
+ get_awaitable_result(
+ event_auth.check_state_independent_auth_rules(event_store, event)
)
+ event_auth.check_state_dependent_auth_rules(event, auth_events)
# ... but a rejected join_rules event should cause it to be rejected
- rejected_join_rules = _join_rules_event(creator, "public")
+ rejected_join_rules = _join_rules_event(
+ RoomVersions.V9,
+ creator,
+ "public",
+ )
rejected_join_rules.rejected_reason = "stinky"
auth_events.append(rejected_join_rules)
+ event_store.add_event(rejected_join_rules)
- self.assertRaises(
- AuthError,
- event_auth.check_auth_rules_for_event,
+ with self.assertRaises(AuthError):
+ get_awaitable_result(
+ event_auth.check_state_independent_auth_rules(
+ event_store,
+ _random_state_event(RoomVersions.V9, creator),
+ )
+ )
+
+ # ... even if there is *also* a good join rules
+ auth_events.append(_join_rules_event(RoomVersions.V9, creator, "public"))
+ event_store.add_event(rejected_join_rules)
+
+ with self.assertRaises(AuthError):
+ get_awaitable_result(
+ event_auth.check_state_independent_auth_rules(
+ event_store,
+ _random_state_event(RoomVersions.V9, creator),
+ )
+ )
+
+ def test_create_event_with_prev_events(self):
+ """A create event with prev_events should be rejected
+
+ https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
+ 1: If type is m.room.create:
+ 1. If it has any previous events, reject.
+ """
+ creator = f"@creator:{TEST_DOMAIN}"
+
+ # we make both a good event and a bad event, to check that we are rejecting
+ # the bad event for the reason we think we are.
+ good_event = make_event_from_dict(
+ {
+ "room_id": TEST_ROOM_ID,
+ "type": "m.room.create",
+ "state_key": "",
+ "sender": creator,
+ "content": {
+ "creator": creator,
+ "room_version": RoomVersions.V9.identifier,
+ },
+ "auth_events": [],
+ "prev_events": [],
+ },
+ room_version=RoomVersions.V9,
+ )
+ bad_event = make_event_from_dict(
+ {**good_event.get_dict(), "prev_events": ["$fakeevent"]},
+ room_version=RoomVersions.V9,
+ )
+
+ event_store = _StubEventSourceStore()
+
+ get_awaitable_result(
+ event_auth.check_state_independent_auth_rules(event_store, good_event)
+ )
+ with self.assertRaises(AuthError):
+ get_awaitable_result(
+ event_auth.check_state_independent_auth_rules(event_store, bad_event)
+ )
+
+ def test_duplicate_auth_events(self):
+ """Events with duplicate auth_events should be rejected
+
+ https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
+ 2. Reject if event has auth_events that:
+ 1. have duplicate entries for a given type and state_key pair
+ """
+ creator = "@creator:example.com"
+
+ create_event = _create_event(RoomVersions.V9, creator)
+ join_event1 = _join_event(RoomVersions.V9, creator)
+ pl_event = _power_levels_event(
RoomVersions.V9,
- _random_state_event(creator),
- auth_events,
+ creator,
+ {"state_default": 30, "users": {"creator": 100}},
)
- # ... even if there is *also* a good join rules
- auth_events.append(_join_rules_event(creator, "public"))
+ # create a second join event, so that we can make a duplicate
+ join_event2 = _join_event(RoomVersions.V9, creator)
- self.assertRaises(
- AuthError,
- event_auth.check_auth_rules_for_event,
+ event_store = _StubEventSourceStore()
+ event_store.add_events([create_event, join_event1, join_event2, pl_event])
+
+ good_event = _random_state_event(
+ RoomVersions.V9, creator, [create_event, join_event2, pl_event]
+ )
+ bad_event = _random_state_event(
+ RoomVersions.V9, creator, [create_event, join_event1, join_event2, pl_event]
+ )
+ # a variation: two instances of the *same* event
+ bad_event2 = _random_state_event(
+ RoomVersions.V9, creator, [create_event, join_event2, join_event2, pl_event]
+ )
+
+ get_awaitable_result(
+ event_auth.check_state_independent_auth_rules(event_store, good_event)
+ )
+ with self.assertRaises(AuthError):
+ get_awaitable_result(
+ event_auth.check_state_independent_auth_rules(event_store, bad_event)
+ )
+ with self.assertRaises(AuthError):
+ get_awaitable_result(
+ event_auth.check_state_independent_auth_rules(event_store, bad_event2)
+ )
+
+ def test_unexpected_auth_events(self):
+ """Events with excess auth_events should be rejected
+
+ https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
+ 2. Reject if event has auth_events that:
+ 2. have entries whose type and state_key don’t match those specified by the
+ auth events selection algorithm described in the server specification.
+ """
+ creator = "@creator:example.com"
+
+ create_event = _create_event(RoomVersions.V9, creator)
+ join_event = _join_event(RoomVersions.V9, creator)
+ pl_event = _power_levels_event(
RoomVersions.V9,
- _random_state_event(creator),
- auth_events,
+ creator,
+ {"state_default": 30, "users": {"creator": 100}},
)
+ join_rules_event = _join_rules_event(RoomVersions.V9, creator, "public")
+
+ event_store = _StubEventSourceStore()
+ event_store.add_events([create_event, join_event, pl_event, join_rules_event])
+
+ good_event = _random_state_event(
+ RoomVersions.V9, creator, [create_event, join_event, pl_event]
+ )
+ # join rules should *not* be included in the auth events.
+ bad_event = _random_state_event(
+ RoomVersions.V9,
+ creator,
+ [create_event, join_event, pl_event, join_rules_event],
+ )
+
+ get_awaitable_result(
+ event_auth.check_state_independent_auth_rules(event_store, good_event)
+ )
+ with self.assertRaises(AuthError):
+ get_awaitable_result(
+ event_auth.check_state_independent_auth_rules(event_store, bad_event)
+ )
def test_random_users_cannot_send_state_before_first_pl(self):
"""
@@ -73,24 +244,22 @@ class EventAuthTestCase(unittest.TestCase):
creator = "@creator:example.com"
joiner = "@joiner:example.com"
auth_events = [
- _create_event(creator),
- _join_event(creator),
- _join_event(joiner),
+ _create_event(RoomVersions.V1, creator),
+ _join_event(RoomVersions.V1, creator),
+ _join_event(RoomVersions.V1, joiner),
]
# creator should be able to send state
- event_auth.check_auth_rules_for_event(
- RoomVersions.V1,
- _random_state_event(creator),
+ event_auth.check_state_dependent_auth_rules(
+ _random_state_event(RoomVersions.V1, creator),
auth_events,
)
# joiner should not be able to send state
self.assertRaises(
AuthError,
- event_auth.check_auth_rules_for_event,
- RoomVersions.V1,
- _random_state_event(joiner),
+ event_auth.check_state_dependent_auth_rules,
+ _random_state_event(RoomVersions.V1, joiner),
auth_events,
)
@@ -104,28 +273,28 @@ class EventAuthTestCase(unittest.TestCase):
king = "@joiner2:example.com"
auth_events = [
- _create_event(creator),
- _join_event(creator),
+ _create_event(RoomVersions.V1, creator),
+ _join_event(RoomVersions.V1, creator),
_power_levels_event(
- creator, {"state_default": "30", "users": {pleb: "29", king: "30"}}
+ RoomVersions.V1,
+ creator,
+ {"state_default": "30", "users": {pleb: "29", king: "30"}},
),
- _join_event(pleb),
- _join_event(king),
+ _join_event(RoomVersions.V1, pleb),
+ _join_event(RoomVersions.V1, king),
]
# pleb should not be able to send state
self.assertRaises(
AuthError,
- event_auth.check_auth_rules_for_event,
- RoomVersions.V1,
- _random_state_event(pleb),
+ event_auth.check_state_dependent_auth_rules,
+ _random_state_event(RoomVersions.V1, pleb),
auth_events,
),
# king should be able to send state
- event_auth.check_auth_rules_for_event(
- RoomVersions.V1,
- _random_state_event(king),
+ event_auth.check_state_dependent_auth_rules(
+ _random_state_event(RoomVersions.V1, king),
auth_events,
)
@@ -134,37 +303,33 @@ class EventAuthTestCase(unittest.TestCase):
creator = "@creator:example.com"
other = "@other:example.com"
auth_events = [
- _create_event(creator),
- _join_event(creator),
+ _create_event(RoomVersions.V1, creator),
+ _join_event(RoomVersions.V1, creator),
]
# creator should be able to send aliases
- event_auth.check_auth_rules_for_event(
- RoomVersions.V1,
- _alias_event(creator),
+ event_auth.check_state_dependent_auth_rules(
+ _alias_event(RoomVersions.V1, creator),
auth_events,
)
# Reject an event with no state key.
with self.assertRaises(AuthError):
- event_auth.check_auth_rules_for_event(
- RoomVersions.V1,
- _alias_event(creator, state_key=""),
+ event_auth.check_state_dependent_auth_rules(
+ _alias_event(RoomVersions.V1, creator, state_key=""),
auth_events,
)
# If the domain of the sender does not match the state key, reject.
with self.assertRaises(AuthError):
- event_auth.check_auth_rules_for_event(
- RoomVersions.V1,
- _alias_event(creator, state_key="test.com"),
+ event_auth.check_state_dependent_auth_rules(
+ _alias_event(RoomVersions.V1, creator, state_key="test.com"),
auth_events,
)
# Note that the member does *not* need to be in the room.
- event_auth.check_auth_rules_for_event(
- RoomVersions.V1,
- _alias_event(other),
+ event_auth.check_state_dependent_auth_rules(
+ _alias_event(RoomVersions.V1, other),
auth_events,
)
@@ -173,38 +338,35 @@ class EventAuthTestCase(unittest.TestCase):
creator = "@creator:example.com"
other = "@other:example.com"
auth_events = [
- _create_event(creator),
- _join_event(creator),
+ _create_event(RoomVersions.V6, creator),
+ _join_event(RoomVersions.V6, creator),
]
# creator should be able to send aliases
- event_auth.check_auth_rules_for_event(
- RoomVersions.V6,
- _alias_event(creator),
+ event_auth.check_state_dependent_auth_rules(
+ _alias_event(RoomVersions.V6, creator),
auth_events,
)
# No particular checks are done on the state key.
- event_auth.check_auth_rules_for_event(
- RoomVersions.V6,
- _alias_event(creator, state_key=""),
+ event_auth.check_state_dependent_auth_rules(
+ _alias_event(RoomVersions.V6, creator, state_key=""),
auth_events,
)
- event_auth.check_auth_rules_for_event(
- RoomVersions.V6,
- _alias_event(creator, state_key="test.com"),
+ event_auth.check_state_dependent_auth_rules(
+ _alias_event(RoomVersions.V6, creator, state_key="test.com"),
auth_events,
)
# Per standard auth rules, the member must be in the room.
with self.assertRaises(AuthError):
- event_auth.check_auth_rules_for_event(
- RoomVersions.V6,
- _alias_event(other),
+ event_auth.check_state_dependent_auth_rules(
+ _alias_event(RoomVersions.V6, other),
auth_events,
)
- def test_msc2209(self):
+ @parameterized.expand([(RoomVersions.V1, True), (RoomVersions.V6, False)])
+ def test_notifications(self, room_version: RoomVersion, allow_modification: bool):
"""
Notifications power levels get checked due to MSC2209.
"""
@@ -212,28 +374,26 @@ class EventAuthTestCase(unittest.TestCase):
pleb = "@joiner:example.com"
auth_events = [
- _create_event(creator),
- _join_event(creator),
+ _create_event(room_version, creator),
+ _join_event(room_version, creator),
_power_levels_event(
- creator, {"state_default": "30", "users": {pleb: "30"}}
+ room_version, creator, {"state_default": "30", "users": {pleb: "30"}}
),
- _join_event(pleb),
+ _join_event(room_version, pleb),
]
- # pleb should be able to modify the notifications power level.
- event_auth.check_auth_rules_for_event(
- RoomVersions.V1,
- _power_levels_event(pleb, {"notifications": {"room": 100}}),
- auth_events,
+ pl_event = _power_levels_event(
+ room_version, pleb, {"notifications": {"room": 100}}
)
- # But an MSC2209 room rejects this change.
- with self.assertRaises(AuthError):
- event_auth.check_auth_rules_for_event(
- RoomVersions.V6,
- _power_levels_event(pleb, {"notifications": {"room": 100}}),
- auth_events,
- )
+ # on room V1, pleb should be able to modify the notifications power level.
+ if allow_modification:
+ event_auth.check_state_dependent_auth_rules(pl_event, auth_events)
+
+ else:
+ # But an MSC2209 room rejects this change.
+ with self.assertRaises(AuthError):
+ event_auth.check_state_dependent_auth_rules(pl_event, auth_events)
def test_join_rules_public(self):
"""
@@ -243,58 +403,60 @@ class EventAuthTestCase(unittest.TestCase):
pleb = "@joiner:example.com"
auth_events = {
- ("m.room.create", ""): _create_event(creator),
- ("m.room.member", creator): _join_event(creator),
- ("m.room.join_rules", ""): _join_rules_event(creator, "public"),
+ ("m.room.create", ""): _create_event(RoomVersions.V6, creator),
+ ("m.room.member", creator): _join_event(RoomVersions.V6, creator),
+ ("m.room.join_rules", ""): _join_rules_event(
+ RoomVersions.V6, creator, "public"
+ ),
}
# Check join.
- event_auth.check_auth_rules_for_event(
- RoomVersions.V6,
- _join_event(pleb),
+ event_auth.check_state_dependent_auth_rules(
+ _join_event(RoomVersions.V6, pleb),
auth_events.values(),
)
# A user cannot be force-joined to a room.
with self.assertRaises(AuthError):
- event_auth.check_auth_rules_for_event(
- RoomVersions.V6,
- _member_event(pleb, "join", sender=creator),
+ event_auth.check_state_dependent_auth_rules(
+ _member_event(RoomVersions.V6, pleb, "join", sender=creator),
auth_events.values(),
)
# Banned should be rejected.
- auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
+ auth_events[("m.room.member", pleb)] = _member_event(
+ RoomVersions.V6, pleb, "ban"
+ )
with self.assertRaises(AuthError):
- event_auth.check_auth_rules_for_event(
- RoomVersions.V6,
- _join_event(pleb),
+ event_auth.check_state_dependent_auth_rules(
+ _join_event(RoomVersions.V6, pleb),
auth_events.values(),
)
# A user who left can re-join.
- auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
- event_auth.check_auth_rules_for_event(
- RoomVersions.V6,
- _join_event(pleb),
+ auth_events[("m.room.member", pleb)] = _member_event(
+ RoomVersions.V6, pleb, "leave"
+ )
+ event_auth.check_state_dependent_auth_rules(
+ _join_event(RoomVersions.V6, pleb),
auth_events.values(),
)
# A user can send a join if they're in the room.
- auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
- event_auth.check_auth_rules_for_event(
- RoomVersions.V6,
- _join_event(pleb),
+ auth_events[("m.room.member", pleb)] = _member_event(
+ RoomVersions.V6, pleb, "join"
+ )
+ event_auth.check_state_dependent_auth_rules(
+ _join_event(RoomVersions.V6, pleb),
auth_events.values(),
)
# A user can accept an invite.
auth_events[("m.room.member", pleb)] = _member_event(
- pleb, "invite", sender=creator
+ RoomVersions.V6, pleb, "invite", sender=creator
)
- event_auth.check_auth_rules_for_event(
- RoomVersions.V6,
- _join_event(pleb),
+ event_auth.check_state_dependent_auth_rules(
+ _join_event(RoomVersions.V6, pleb),
auth_events.values(),
)
@@ -306,64 +468,88 @@ class EventAuthTestCase(unittest.TestCase):
pleb = "@joiner:example.com"
auth_events = {
- ("m.room.create", ""): _create_event(creator),
- ("m.room.member", creator): _join_event(creator),
- ("m.room.join_rules", ""): _join_rules_event(creator, "invite"),
+ ("m.room.create", ""): _create_event(RoomVersions.V6, creator),
+ ("m.room.member", creator): _join_event(RoomVersions.V6, creator),
+ ("m.room.join_rules", ""): _join_rules_event(
+ RoomVersions.V6, creator, "invite"
+ ),
}
# A join without an invite is rejected.
with self.assertRaises(AuthError):
- event_auth.check_auth_rules_for_event(
- RoomVersions.V6,
- _join_event(pleb),
+ event_auth.check_state_dependent_auth_rules(
+ _join_event(RoomVersions.V6, pleb),
auth_events.values(),
)
# A user cannot be force-joined to a room.
with self.assertRaises(AuthError):
- event_auth.check_auth_rules_for_event(
- RoomVersions.V6,
- _member_event(pleb, "join", sender=creator),
+ event_auth.check_state_dependent_auth_rules(
+ _member_event(RoomVersions.V6, pleb, "join", sender=creator),
auth_events.values(),
)
# Banned should be rejected.
- auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
+ auth_events[("m.room.member", pleb)] = _member_event(
+ RoomVersions.V6, pleb, "ban"
+ )
with self.assertRaises(AuthError):
- event_auth.check_auth_rules_for_event(
- RoomVersions.V6,
- _join_event(pleb),
+ event_auth.check_state_dependent_auth_rules(
+ _join_event(RoomVersions.V6, pleb),
auth_events.values(),
)
# A user who left cannot re-join.
- auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
+ auth_events[("m.room.member", pleb)] = _member_event(
+ RoomVersions.V6, pleb, "leave"
+ )
with self.assertRaises(AuthError):
- event_auth.check_auth_rules_for_event(
- RoomVersions.V6,
- _join_event(pleb),
+ event_auth.check_state_dependent_auth_rules(
+ _join_event(RoomVersions.V6, pleb),
auth_events.values(),
)
# A user can send a join if they're in the room.
- auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
- event_auth.check_auth_rules_for_event(
- RoomVersions.V6,
- _join_event(pleb),
+ auth_events[("m.room.member", pleb)] = _member_event(
+ RoomVersions.V6, pleb, "join"
+ )
+ event_auth.check_state_dependent_auth_rules(
+ _join_event(RoomVersions.V6, pleb),
auth_events.values(),
)
# A user can accept an invite.
auth_events[("m.room.member", pleb)] = _member_event(
- pleb, "invite", sender=creator
+ RoomVersions.V6, pleb, "invite", sender=creator
)
- event_auth.check_auth_rules_for_event(
- RoomVersions.V6,
- _join_event(pleb),
+ event_auth.check_state_dependent_auth_rules(
+ _join_event(RoomVersions.V6, pleb),
auth_events.values(),
)
- def test_join_rules_msc3083_restricted(self):
+ def test_join_rules_restricted_old_room(self) -> None:
+ """Old room versions should reject joins to restricted rooms"""
+ creator = "@creator:example.com"
+ pleb = "@joiner:example.com"
+
+ auth_events = {
+ ("m.room.create", ""): _create_event(RoomVersions.V6, creator),
+ ("m.room.member", creator): _join_event(RoomVersions.V6, creator),
+ ("m.room.power_levels", ""): _power_levels_event(
+ RoomVersions.V6, creator, {"invite": 0}
+ ),
+ ("m.room.join_rules", ""): _join_rules_event(
+ RoomVersions.V6, creator, "restricted"
+ ),
+ }
+
+ with self.assertRaises(AuthError):
+ event_auth.check_state_dependent_auth_rules(
+ _join_event(RoomVersions.V6, pleb),
+ auth_events.values(),
+ )
+
+ def test_join_rules_msc3083_restricted(self) -> None:
"""
Test joining a restricted room from MSC3083.
@@ -377,29 +563,25 @@ class EventAuthTestCase(unittest.TestCase):
pleb = "@joiner:example.com"
auth_events = {
- ("m.room.create", ""): _create_event(creator),
- ("m.room.member", creator): _join_event(creator),
- ("m.room.power_levels", ""): _power_levels_event(creator, {"invite": 0}),
- ("m.room.join_rules", ""): _join_rules_event(creator, "restricted"),
+ ("m.room.create", ""): _create_event(RoomVersions.V8, creator),
+ ("m.room.member", creator): _join_event(RoomVersions.V8, creator),
+ ("m.room.power_levels", ""): _power_levels_event(
+ RoomVersions.V8, creator, {"invite": 0}
+ ),
+ ("m.room.join_rules", ""): _join_rules_event(
+ RoomVersions.V8, creator, "restricted"
+ ),
}
- # Older room versions don't understand this join rule
- with self.assertRaises(AuthError):
- event_auth.check_auth_rules_for_event(
- RoomVersions.V6,
- _join_event(pleb),
- auth_events.values(),
- )
-
# A properly formatted join event should work.
authorised_join_event = _join_event(
+ RoomVersions.V8,
pleb,
additional_content={
EventContentFields.AUTHORISING_USER: "@creator:example.com"
},
)
- event_auth.check_auth_rules_for_event(
- RoomVersions.V8,
+ event_auth.check_state_dependent_auth_rules(
authorised_join_event,
auth_events.values(),
)
@@ -408,14 +590,16 @@ class EventAuthTestCase(unittest.TestCase):
# are done properly).
pl_auth_events = auth_events.copy()
pl_auth_events[("m.room.power_levels", "")] = _power_levels_event(
- creator, {"invite": 100, "users": {"@inviter:foo.test": 150}}
+ RoomVersions.V8,
+ creator,
+ {"invite": 100, "users": {"@inviter:foo.test": 150}},
)
pl_auth_events[("m.room.member", "@inviter:foo.test")] = _join_event(
- "@inviter:foo.test"
+ RoomVersions.V8, "@inviter:foo.test"
)
- event_auth.check_auth_rules_for_event(
- RoomVersions.V8,
+ event_auth.check_state_dependent_auth_rules(
_join_event(
+ RoomVersions.V8,
pleb,
additional_content={
EventContentFields.AUTHORISING_USER: "@inviter:foo.test"
@@ -426,21 +610,22 @@ class EventAuthTestCase(unittest.TestCase):
# A join which is missing an authorised server is rejected.
with self.assertRaises(AuthError):
- event_auth.check_auth_rules_for_event(
- RoomVersions.V8,
- _join_event(pleb),
+ event_auth.check_state_dependent_auth_rules(
+ _join_event(RoomVersions.V8, pleb),
auth_events.values(),
)
# An join authorised by a user who is not in the room is rejected.
pl_auth_events = auth_events.copy()
pl_auth_events[("m.room.power_levels", "")] = _power_levels_event(
- creator, {"invite": 100, "users": {"@other:example.com": 150}}
+ RoomVersions.V8,
+ creator,
+ {"invite": 100, "users": {"@other:example.com": 150}},
)
with self.assertRaises(AuthError):
- event_auth.check_auth_rules_for_event(
- RoomVersions.V8,
+ event_auth.check_state_dependent_auth_rules(
_join_event(
+ RoomVersions.V8,
pleb,
additional_content={
EventContentFields.AUTHORISING_USER: "@other:example.com"
@@ -452,9 +637,9 @@ class EventAuthTestCase(unittest.TestCase):
# A user cannot be force-joined to a room. (This uses an event which
# *would* be valid, but is sent be a different user.)
with self.assertRaises(AuthError):
- event_auth.check_auth_rules_for_event(
- RoomVersions.V8,
+ event_auth.check_state_dependent_auth_rules(
_member_event(
+ RoomVersions.V8,
pleb,
"join",
sender=creator,
@@ -466,62 +651,109 @@ class EventAuthTestCase(unittest.TestCase):
)
# Banned should be rejected.
- auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
+ auth_events[("m.room.member", pleb)] = _member_event(
+ RoomVersions.V8, pleb, "ban"
+ )
with self.assertRaises(AuthError):
- event_auth.check_auth_rules_for_event(
- RoomVersions.V8,
+ event_auth.check_state_dependent_auth_rules(
authorised_join_event,
auth_events.values(),
)
# A user who left can re-join.
- auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
- event_auth.check_auth_rules_for_event(
- RoomVersions.V8,
+ auth_events[("m.room.member", pleb)] = _member_event(
+ RoomVersions.V8, pleb, "leave"
+ )
+ event_auth.check_state_dependent_auth_rules(
authorised_join_event,
auth_events.values(),
)
# A user can send a join if they're in the room. (This doesn't need to
# be authorised since the user is already joined.)
- auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
- event_auth.check_auth_rules_for_event(
- RoomVersions.V8,
- _join_event(pleb),
+ auth_events[("m.room.member", pleb)] = _member_event(
+ RoomVersions.V8, pleb, "join"
+ )
+ event_auth.check_state_dependent_auth_rules(
+ _join_event(RoomVersions.V8, pleb),
auth_events.values(),
)
# A user can accept an invite. (This doesn't need to be authorised since
# the user was invited.)
auth_events[("m.room.member", pleb)] = _member_event(
- pleb, "invite", sender=creator
+ RoomVersions.V8, pleb, "invite", sender=creator
)
- event_auth.check_auth_rules_for_event(
- RoomVersions.V8,
- _join_event(pleb),
+ event_auth.check_state_dependent_auth_rules(
+ _join_event(RoomVersions.V8, pleb),
auth_events.values(),
)
+ def test_room_v10_rejects_string_power_levels(self) -> None:
+ pl_event_content = {"users_default": "42"}
+ pl_event = make_event_from_dict(
+ {
+ "room_id": TEST_ROOM_ID,
+ **_maybe_get_event_id_dict_for_room_version(RoomVersions.V10),
+ "type": "m.room.power_levels",
+ "sender": "@test:test.com",
+ "state_key": "",
+ "content": pl_event_content,
+ "signatures": {"test.com": {"ed25519:0": "some9signature"}},
+ },
+ room_version=RoomVersions.V10,
+ )
-# helpers for making events
+ pl_event2_content = {"events": {"m.room.name": "42", "m.room.power_levels": 42}}
+ pl_event2 = make_event_from_dict(
+ {
+ "room_id": TEST_ROOM_ID,
+ **_maybe_get_event_id_dict_for_room_version(RoomVersions.V10),
+ "type": "m.room.power_levels",
+ "sender": "@test:test.com",
+ "state_key": "",
+ "content": pl_event2_content,
+ "signatures": {"test.com": {"ed25519:0": "some9signature"}},
+ },
+ room_version=RoomVersions.V10,
+ )
+
+ with self.assertRaises(SynapseError):
+ event_auth._check_power_levels(
+ pl_event.room_version, pl_event, {("fake_type", "fake_key"): pl_event2}
+ )
+
+ with self.assertRaises(SynapseError):
+ event_auth._check_power_levels(
+ pl_event.room_version, pl_event2, {("fake_type", "fake_key"): pl_event}
+ )
-TEST_ROOM_ID = "!test:room"
+
+# helpers for making events
+TEST_DOMAIN = "example.com"
+TEST_ROOM_ID = f"!test_room:{TEST_DOMAIN}"
-def _create_event(user_id: str) -> EventBase:
+def _create_event(
+ room_version: RoomVersion,
+ user_id: str,
+) -> EventBase:
return make_event_from_dict(
{
"room_id": TEST_ROOM_ID,
- "event_id": _get_event_id(),
+ **_maybe_get_event_id_dict_for_room_version(room_version),
"type": "m.room.create",
"state_key": "",
"sender": user_id,
"content": {"creator": user_id},
- }
+ "auth_events": [],
+ },
+ room_version=room_version,
)
def _member_event(
+ room_version: RoomVersion,
user_id: str,
membership: str,
sender: Optional[str] = None,
@@ -530,79 +762,119 @@ def _member_event(
return make_event_from_dict(
{
"room_id": TEST_ROOM_ID,
- "event_id": _get_event_id(),
+ **_maybe_get_event_id_dict_for_room_version(room_version),
"type": "m.room.member",
"sender": sender or user_id,
"state_key": user_id,
"content": {"membership": membership, **(additional_content or {})},
+ "auth_events": [],
"prev_events": [],
- }
+ },
+ room_version=room_version,
)
-def _join_event(user_id: str, additional_content: Optional[dict] = None) -> EventBase:
- return _member_event(user_id, "join", additional_content=additional_content)
+def _join_event(
+ room_version: RoomVersion,
+ user_id: str,
+ additional_content: Optional[dict] = None,
+) -> EventBase:
+ return _member_event(
+ room_version,
+ user_id,
+ "join",
+ additional_content=additional_content,
+ )
-def _power_levels_event(sender: str, content: JsonDict) -> EventBase:
+def _power_levels_event(
+ room_version: RoomVersion,
+ sender: str,
+ content: JsonDict,
+) -> EventBase:
return make_event_from_dict(
{
"room_id": TEST_ROOM_ID,
- "event_id": _get_event_id(),
+ **_maybe_get_event_id_dict_for_room_version(room_version),
"type": "m.room.power_levels",
"sender": sender,
"state_key": "",
"content": content,
- }
+ },
+ room_version=room_version,
)
-def _alias_event(sender: str, **kwargs) -> EventBase:
+def _alias_event(room_version: RoomVersion, sender: str, **kwargs) -> EventBase:
data = {
"room_id": TEST_ROOM_ID,
- "event_id": _get_event_id(),
+ **_maybe_get_event_id_dict_for_room_version(room_version),
"type": "m.room.aliases",
"sender": sender,
"state_key": get_domain_from_id(sender),
"content": {"aliases": []},
}
data.update(**kwargs)
- return make_event_from_dict(data)
+ return make_event_from_dict(data, room_version=room_version)
+
+def _build_auth_dict_for_room_version(
+ room_version: RoomVersion, auth_events: Iterable[EventBase]
+) -> List:
+ if room_version.event_format == EventFormatVersions.V1:
+ return [(e.event_id, "not_used") for e in auth_events]
+ else:
+ return [e.event_id for e in auth_events]
-def _random_state_event(sender: str) -> EventBase:
+
+def _random_state_event(
+ room_version: RoomVersion,
+ sender: str,
+ auth_events: Optional[Iterable[EventBase]] = None,
+) -> EventBase:
+ if auth_events is None:
+ auth_events = []
return make_event_from_dict(
{
"room_id": TEST_ROOM_ID,
- "event_id": _get_event_id(),
+ **_maybe_get_event_id_dict_for_room_version(room_version),
"type": "test.state",
"sender": sender,
"state_key": "",
"content": {"membership": "join"},
- }
+ "auth_events": _build_auth_dict_for_room_version(room_version, auth_events),
+ },
+ room_version=room_version,
)
-def _join_rules_event(sender: str, join_rule: str) -> EventBase:
+def _join_rules_event(
+ room_version: RoomVersion, sender: str, join_rule: str
+) -> EventBase:
return make_event_from_dict(
{
"room_id": TEST_ROOM_ID,
- "event_id": _get_event_id(),
+ **_maybe_get_event_id_dict_for_room_version(room_version),
"type": "m.room.join_rules",
"sender": sender,
"state_key": "",
"content": {
"join_rule": join_rule,
},
- }
+ },
+ room_version=room_version,
)
event_count = 0
-def _get_event_id() -> str:
+def _maybe_get_event_id_dict_for_room_version(room_version: RoomVersion) -> dict:
+ """If this room version needs it, generate an event id"""
+ if room_version.event_format != EventFormatVersions.V1:
+ return {}
+
global event_count
c = event_count
event_count += 1
- return "!%i:example.com" % (c,)
+ return {"event_id": "!%i:example.com" % (c,)}
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 0cbef70bfa..779fad1f63 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -81,12 +81,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.handler = self.homeserver.get_federation_handler()
federation_event_handler = self.homeserver.get_federation_event_handler()
- async def _check_event_auth(
- origin,
- event,
- context,
- ):
- return context
+ async def _check_event_auth(origin, event, context):
+ pass
federation_event_handler._check_event_auth = _check_event_auth
self.client = self.homeserver.get_federation_client()
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index b4574b2ffe..1a70eddc9b 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -12,7 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+try:
+ from importlib import metadata
+except ImportError:
+ import importlib_metadata as metadata # type: ignore[no-redef]
+from unittest.mock import patch
+
+from pkg_resources import parse_version
+
+from synapse.app._base import _set_prometheus_client_use_created_metrics
from synapse.metrics import REGISTRY, InFlightGauge, generate_latest
from synapse.util.caches.deferred_cache import DeferredCache
@@ -162,3 +171,30 @@ class CacheMetricsTests(unittest.HomeserverTestCase):
self.assertEqual(items["synapse_util_caches_cache_size"], "1.0")
self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0")
+
+
+class PrometheusMetricsHackTestCase(unittest.HomeserverTestCase):
+ if parse_version(metadata.version("prometheus_client")) < parse_version("0.14.0"):
+ skip = "prometheus-client too old"
+
+ def test_created_metrics_disabled(self) -> None:
+ """
+ Tests that a brittle hack, to disable `_created` metrics, works.
+ This involves poking at the internals of prometheus-client.
+ It's not the end of the world if this doesn't work.
+
+ This test gives us a way to notice if prometheus-client changes
+ their internals.
+ """
+ import prometheus_client.metrics
+
+ PRIVATE_FLAG_NAME = "_use_created"
+
+ # By default, the pesky `_created` metrics are enabled.
+ # Check this assumption is still valid.
+ self.assertTrue(getattr(prometheus_client.metrics, PRIVATE_FLAG_NAME))
+
+ with patch("prometheus_client.metrics") as mock:
+ setattr(mock, PRIVATE_FLAG_NAME, True)
+ _set_prometheus_client_use_created_metrics(False)
+ self.assertFalse(getattr(mock, PRIVATE_FLAG_NAME, False))
diff --git a/tests/test_server.py b/tests/test_server.py
index 0f1eb43cbc..23975d59c3 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -14,7 +14,7 @@
import re
from http import HTTPStatus
-from typing import Tuple
+from typing import Awaitable, Callable, Dict, NoReturn, Optional, Tuple
from twisted.internet.defer import Deferred
from twisted.web.resource import Resource
@@ -26,16 +26,17 @@ from synapse.http.server import (
DirectServeJsonResource,
JsonResource,
OptionsResource,
- cancellable,
)
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict
from synapse.util import Clock
+from synapse.util.cancellation import cancellable
from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
from tests.server import (
+ FakeChannel,
FakeSite,
ThreadedMemoryReactorClock,
make_request,
@@ -44,7 +45,7 @@ from tests.server import (
class JsonResourceTests(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.reactor = ThreadedMemoryReactorClock()
self.hs_clock = Clock(self.reactor)
self.homeserver = setup_test_homeserver(
@@ -54,7 +55,7 @@ class JsonResourceTests(unittest.TestCase):
reactor=self.reactor,
)
- def test_handler_for_request(self):
+ def test_handler_for_request(self) -> None:
"""
JsonResource.handler_for_request gives correctly decoded URL args to
the callback, while Twisted will give the raw bytes of URL query
@@ -62,7 +63,9 @@ class JsonResourceTests(unittest.TestCase):
"""
got_kwargs = {}
- def _callback(request, **kwargs):
+ def _callback(
+ request: SynapseRequest, **kwargs: object
+ ) -> Tuple[int, Dict[str, object]]:
got_kwargs.update(kwargs)
return 200, kwargs
@@ -83,13 +86,13 @@ class JsonResourceTests(unittest.TestCase):
self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"})
- def test_callback_direct_exception(self):
+ def test_callback_direct_exception(self) -> None:
"""
If the web callback raises an uncaught exception, it will be translated
into a 500.
"""
- def _callback(request, **kwargs):
+ def _callback(request: SynapseRequest, **kwargs: object) -> NoReturn:
raise Exception("boo")
res = JsonResource(self.homeserver)
@@ -101,19 +104,19 @@ class JsonResourceTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
)
- self.assertEqual(channel.result["code"], b"500")
+ self.assertEqual(channel.code, 500)
- def test_callback_indirect_exception(self):
+ def test_callback_indirect_exception(self) -> None:
"""
If the web callback raises an uncaught exception in a Deferred, it will
be translated into a 500.
"""
- def _throw(*args):
+ def _throw(*args: object) -> NoReturn:
raise Exception("boo")
- def _callback(request, **kwargs):
- d = Deferred()
+ def _callback(request: SynapseRequest, **kwargs: object) -> "Deferred[None]":
+ d: "Deferred[None]" = Deferred()
d.addCallback(_throw)
self.reactor.callLater(0.5, d.callback, True)
return make_deferred_yieldable(d)
@@ -127,15 +130,15 @@ class JsonResourceTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
)
- self.assertEqual(channel.result["code"], b"500")
+ self.assertEqual(channel.code, 500)
- def test_callback_synapseerror(self):
+ def test_callback_synapseerror(self) -> None:
"""
If the web callback raises a SynapseError, it returns the appropriate
status code and message set in it.
"""
- def _callback(request, **kwargs):
+ def _callback(request: SynapseRequest, **kwargs: object) -> NoReturn:
raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN)
res = JsonResource(self.homeserver)
@@ -147,16 +150,16 @@ class JsonResourceTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
)
- self.assertEqual(channel.result["code"], b"403")
+ self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
- def test_no_handler(self):
+ def test_no_handler(self) -> None:
"""
If there is no handler to process the request, Synapse will return 400.
"""
- def _callback(request, **kwargs):
+ def _callback(request: SynapseRequest, **kwargs: object) -> None:
"""
Not ever actually called!
"""
@@ -171,18 +174,20 @@ class JsonResourceTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foobar"
)
- self.assertEqual(channel.result["code"], b"400")
+ self.assertEqual(channel.code, 400)
self.assertEqual(channel.json_body["error"], "Unrecognized request")
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
- def test_head_request(self):
+ def test_head_request(self) -> None:
"""
JsonResource.handler_for_request gives correctly decoded URL args to
the callback, while Twisted will give the raw bytes of URL query
arguments.
"""
- def _callback(request, **kwargs):
+ def _callback(
+ request: SynapseRequest, **kwargs: object
+ ) -> Tuple[int, Dict[str, object]]:
return 200, {"result": True}
res = JsonResource(self.homeserver)
@@ -198,25 +203,26 @@ class JsonResourceTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/_matrix/foo"
)
- self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.code, 200)
self.assertNotIn("body", channel.result)
class OptionsResourceTests(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.reactor = ThreadedMemoryReactorClock()
class DummyResource(Resource):
isLeaf = True
- def render(self, request):
- return request.path
+ def render(self, request: SynapseRequest) -> bytes:
+ # Type-ignore: mypy thinks request.path is Optional[Any], not bytes.
+ return request.path # type: ignore[return-value]
# Setup a resource with some children.
self.resource = OptionsResource()
self.resource.putChild(b"res", DummyResource())
- def _make_request(self, method, path):
+ def _make_request(self, method: bytes, path: bytes) -> FakeChannel:
"""Create a request from the method/path and return a channel with the response."""
# Create a site and query for the resource.
site = SynapseSite(
@@ -225,7 +231,7 @@ class OptionsResourceTests(unittest.TestCase):
parse_listener_def({"type": "http", "port": 0}),
self.resource,
"1.0",
- max_request_body_size=1234,
+ max_request_body_size=4096,
reactor=self.reactor,
)
@@ -233,10 +239,10 @@ class OptionsResourceTests(unittest.TestCase):
channel = make_request(self.reactor, site, method, path, shorthand=False)
return channel
- def test_unknown_options_request(self):
+ def test_unknown_options_request(self) -> None:
"""An OPTIONS requests to an unknown URL still returns 204 No Content."""
channel = self._make_request(b"OPTIONS", b"/foo/")
- self.assertEqual(channel.result["code"], b"204")
+ self.assertEqual(channel.code, 204)
self.assertNotIn("body", channel.result)
# Ensure the correct CORS headers have been added
@@ -253,10 +259,10 @@ class OptionsResourceTests(unittest.TestCase):
"has CORS Headers header",
)
- def test_known_options_request(self):
+ def test_known_options_request(self) -> None:
"""An OPTIONS requests to an known URL still returns 204 No Content."""
channel = self._make_request(b"OPTIONS", b"/res/")
- self.assertEqual(channel.result["code"], b"204")
+ self.assertEqual(channel.code, 204)
self.assertNotIn("body", channel.result)
# Ensure the correct CORS headers have been added
@@ -273,30 +279,31 @@ class OptionsResourceTests(unittest.TestCase):
"has CORS Headers header",
)
- def test_unknown_request(self):
+ def test_unknown_request(self) -> None:
"""A non-OPTIONS request to an unknown URL should 404."""
channel = self._make_request(b"GET", b"/foo/")
- self.assertEqual(channel.result["code"], b"404")
+ self.assertEqual(channel.code, 404)
- def test_known_request(self):
+ def test_known_request(self) -> None:
"""A non-OPTIONS request to an known URL should query the proper resource."""
channel = self._make_request(b"GET", b"/res/")
- self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.result["body"], b"/res/")
class WrapHtmlRequestHandlerTests(unittest.TestCase):
class TestResource(DirectServeHtmlResource):
- callback = None
+ callback: Optional[Callable[..., Awaitable[None]]]
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: SynapseRequest) -> None:
+ assert self.callback is not None
await self.callback(request)
- def setUp(self):
+ def setUp(self) -> None:
self.reactor = ThreadedMemoryReactorClock()
- def test_good_response(self):
- async def callback(request):
+ def test_good_response(self) -> None:
+ async def callback(request: SynapseRequest) -> None:
request.write(b"response")
request.finish()
@@ -307,17 +314,17 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
)
- self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.code, 200)
body = channel.result["body"]
self.assertEqual(body, b"response")
- def test_redirect_exception(self):
+ def test_redirect_exception(self) -> None:
"""
If the callback raises a RedirectException, it is turned into a 30x
with the right location.
"""
- async def callback(request, **kwargs):
+ async def callback(request: SynapseRequest, **kwargs: object) -> None:
raise RedirectException(b"/look/an/eagle", 301)
res = WrapHtmlRequestHandlerTests.TestResource()
@@ -327,18 +334,18 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
)
- self.assertEqual(channel.result["code"], b"301")
+ self.assertEqual(channel.code, 301)
headers = channel.result["headers"]
location_headers = [v for k, v in headers if k == b"Location"]
self.assertEqual(location_headers, [b"/look/an/eagle"])
- def test_redirect_exception_with_cookie(self):
+ def test_redirect_exception_with_cookie(self) -> None:
"""
If the callback raises a RedirectException which sets a cookie, that is
returned too
"""
- async def callback(request, **kwargs):
+ async def callback(request: SynapseRequest, **kwargs: object) -> NoReturn:
e = RedirectException(b"/no/over/there", 304)
e.cookies.append(b"session=yespls")
raise e
@@ -350,17 +357,17 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
)
- self.assertEqual(channel.result["code"], b"304")
+ self.assertEqual(channel.code, 304)
headers = channel.result["headers"]
location_headers = [v for k, v in headers if k == b"Location"]
self.assertEqual(location_headers, [b"/no/over/there"])
cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
self.assertEqual(cookies_headers, [b"session=yespls"])
- def test_head_request(self):
+ def test_head_request(self) -> None:
"""A head request should work by being turned into a GET request."""
- async def callback(request):
+ async def callback(request: SynapseRequest) -> None:
request.write(b"response")
request.finish()
@@ -371,7 +378,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/path"
)
- self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.code, 200)
self.assertNotIn("body", channel.result)
@@ -407,10 +414,10 @@ class CancellableDirectServeHtmlResource(DirectServeHtmlResource):
return HTTPStatus.OK, b"ok"
-class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin):
+class DirectServeJsonResourceCancellationTests(unittest.TestCase):
"""Tests for `DirectServeJsonResource` cancellation."""
- def setUp(self):
+ def setUp(self) -> None:
self.reactor = ThreadedMemoryReactorClock()
self.clock = Clock(self.reactor)
self.resource = CancellableDirectServeJsonResource(self.clock)
@@ -421,7 +428,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix
channel = make_request(
self.reactor, self.site, "GET", "/sleep", await_result=False
)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@@ -433,7 +440,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix
channel = make_request(
self.reactor, self.site, "POST", "/sleep", await_result=False
)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
@@ -441,10 +448,10 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix
)
-class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin):
+class DirectServeHtmlResourceCancellationTests(unittest.TestCase):
"""Tests for `DirectServeHtmlResource` cancellation."""
- def setUp(self):
+ def setUp(self) -> None:
self.reactor = ThreadedMemoryReactorClock()
self.clock = Clock(self.reactor)
self.resource = CancellableDirectServeHtmlResource(self.clock)
@@ -455,7 +462,7 @@ class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMix
channel = make_request(
self.reactor, self.site, "GET", "/sleep", await_result=False
)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@@ -467,6 +474,6 @@ class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMix
channel = make_request(
self.reactor, self.site, "POST", "/sleep", await_result=False
)
- self._test_disconnect(
+ test_disconnect(
self.reactor, channel, expect_cancellation=False, expected_body=b"ok"
)
diff --git a/tests/test_state.py b/tests/test_state.py
index 95f81bebae..504530b49a 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Collection, Dict, List, Optional
+from typing import Collection, Dict, List, Optional, cast
from unittest.mock import Mock
from twisted.internet import defer
@@ -21,7 +21,9 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
-from synapse.state import StateHandler, StateResolutionHandler
+from synapse.state import StateHandler, StateResolutionHandler, _make_state_cache_entry
+from synapse.util import Clock
+from synapse.util.macaroons import MacaroonGenerator
from tests import unittest
@@ -97,6 +99,10 @@ class _DummyStore:
state_group = self._next_group
self._next_group += 1
+ if current_state_ids is None:
+ current_state_ids = dict(self._group_to_state[prev_group])
+ current_state_ids.update(delta_ids)
+
self._group_to_state[state_group] = dict(current_state_ids)
return state_group
@@ -129,7 +135,9 @@ class _DummyStore:
async def get_room_version_id(self, room_id):
return RoomVersions.V1.identifier
- async def get_state_group_for_events(self, event_ids):
+ async def get_state_group_for_events(
+ self, event_ids, await_full_state: bool = True
+ ):
res = {}
for event in event_ids:
res[event] = self._event_to_state_group[event]
@@ -190,13 +198,20 @@ class StateTestCase(unittest.TestCase):
"get_clock",
"get_state_resolution_handler",
"get_account_validity_handler",
+ "get_macaroon_generator",
+ "get_instance_name",
+ "get_simple_http_client",
"hostname",
]
)
+ clock = cast(Clock, MockClock())
hs.config = default_config("tesths", True)
hs.get_datastores.return_value = Mock(main=self.dummy_store)
hs.get_state_handler.return_value = None
- hs.get_clock.return_value = MockClock()
+ hs.get_clock.return_value = clock
+ hs.get_macaroon_generator.return_value = MacaroonGenerator(
+ clock, "tesths", b"verysecret"
+ )
hs.get_auth.return_value = Auth(hs)
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
hs.get_storage_controllers.return_value = storage_controllers
@@ -447,6 +462,7 @@ class StateTestCase(unittest.TestCase):
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
+ partial_state=False,
)
)
@@ -477,6 +493,7 @@ class StateTestCase(unittest.TestCase):
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
+ partial_state=False,
)
)
@@ -749,3 +766,43 @@ class StateTestCase(unittest.TestCase):
result = yield defer.ensureDeferred(self.state.compute_event_context(event))
return result
+
+ def test_make_state_cache_entry(self):
+ "Test that calculating a prev_group and delta is correct"
+
+ new_state = {
+ ("a", ""): "E",
+ ("b", ""): "E",
+ ("c", ""): "E",
+ ("d", ""): "E",
+ }
+
+ # old_state_1 has fewer differences to new_state than old_state_2, but
+ # the delta involves deleting a key, which isn't allowed in the deltas,
+ # so we should pick old_state_2 as the prev_group.
+
+ # `old_state_1` has two differences: `a` and `e`
+ old_state_1 = {
+ ("a", ""): "F",
+ ("b", ""): "E",
+ ("c", ""): "E",
+ ("d", ""): "E",
+ ("e", ""): "E",
+ }
+
+ # `old_state_2` has three differences: `a`, `c` and `d`
+ old_state_2 = {
+ ("a", ""): "F",
+ ("b", ""): "E",
+ ("c", ""): "F",
+ ("d", ""): "F",
+ }
+
+ entry = _make_state_cache_entry(new_state, {1: old_state_1, 2: old_state_2})
+
+ self.assertEqual(entry.prev_group, 2)
+
+ # There are three changes from `old_state_2` to `new_state`
+ self.assertEqual(
+ entry.delta_ids, {("a", ""): "E", ("c", ""): "E", ("d", ""): "E"}
+ )
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 37fada5c53..abd7459a8c 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactorClock
@@ -51,10 +50,10 @@ class TermsTestCase(unittest.HomeserverTestCase):
def test_ui_auth(self):
# Do a UI auth request
- request_data = json.dumps({"username": "kermit", "password": "monkey"})
+ request_data = {"username": "kermit", "password": "monkey"}
channel = self.make_request(b"POST", self.url, request_data)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, channel.result)
self.assertTrue(channel.json_body is not None)
self.assertIsInstance(channel.json_body["session"], str)
@@ -82,16 +81,14 @@ class TermsTestCase(unittest.HomeserverTestCase):
self.assertDictContainsSubset(channel.json_body["params"], expected_params)
# We have to complete the dummy auth stage before completing the terms stage
- request_data = json.dumps(
- {
- "username": "kermit",
- "password": "monkey",
- "auth": {
- "session": channel.json_body["session"],
- "type": "m.login.dummy",
- },
- }
- )
+ request_data = {
+ "username": "kermit",
+ "password": "monkey",
+ "auth": {
+ "session": channel.json_body["session"],
+ "type": "m.login.dummy",
+ },
+ }
self.registration_handler.check_username = Mock(return_value=True)
@@ -99,25 +96,23 @@ class TermsTestCase(unittest.HomeserverTestCase):
# We don't bother checking that the response is correct - we'll leave that to
# other tests. We just want to make sure we're on the right path.
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, channel.result)
# Finish the UI auth for terms
- request_data = json.dumps(
- {
- "username": "kermit",
- "password": "monkey",
- "auth": {
- "session": channel.json_body["session"],
- "type": "m.login.terms",
- },
- }
- )
+ request_data = {
+ "username": "kermit",
+ "password": "monkey",
+ "auth": {
+ "session": channel.json_body["session"],
+ "type": "m.login.terms",
+ },
+ }
channel = self.make_request(b"POST", self.url, request_data)
# We're interested in getting a response that looks like a successful
# registration, not so much that the details are exactly what we want.
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
self.assertTrue(channel.json_body is not None)
self.assertIsInstance(channel.json_body["user_id"], str)
diff --git a/tests/test_types.py b/tests/test_types.py
index 0b10dae848..d8d82a517e 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -26,10 +26,21 @@ class UserIDTestCase(unittest.HomeserverTestCase):
self.assertEqual("test", user.domain)
self.assertEqual(True, self.hs.is_mine(user))
- def test_pase_empty(self):
+ def test_parse_rejects_empty_id(self):
with self.assertRaises(SynapseError):
UserID.from_string("")
+ def test_parse_rejects_missing_sigil(self):
+ with self.assertRaises(SynapseError):
+ UserID.from_string("alice:example.com")
+
+ def test_parse_rejects_missing_separator(self):
+ with self.assertRaises(SynapseError):
+ UserID.from_string("@alice.example.com")
+
+ def test_validation_rejects_missing_domain(self):
+ self.assertFalse(UserID.is_valid("@alice:"))
+
def test_build(self):
user = UserID("5678efgh", "my.domain")
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index f338af6c36..c385b2f8d4 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -272,7 +272,7 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
"state_key": "@user:test",
"content": {"membership": "invite"},
}
- self.add_hashes_and_signatures(invite_pdu)
+ self.add_hashes_and_signatures_from_other_server(invite_pdu)
invite_event_id = make_event_from_dict(invite_pdu, RoomVersions.V9).event_id
self.get_success(
diff --git a/tests/unittest.py b/tests/unittest.py
index e7f255b4fa..975b0a23a7 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -16,7 +16,6 @@
import gc
import hashlib
import hmac
-import json
import logging
import secrets
import time
@@ -29,6 +28,7 @@ from typing import (
Generic,
Iterable,
List,
+ NoReturn,
Optional,
Tuple,
Type,
@@ -40,7 +40,7 @@ from unittest.mock import Mock, patch
import canonicaljson
import signedjson.key
import unpaddedbase64
-from typing_extensions import Protocol
+from typing_extensions import Concatenate, ParamSpec, Protocol
from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure
@@ -68,7 +68,7 @@ from synapse.logging.context import (
from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer
from synapse.storage.keys import FetchKeyResult
-from synapse.types import JsonDict, UserID, create_requester
+from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree
@@ -89,6 +89,10 @@ setup_logging()
TV = TypeVar("TV")
_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)
+P = ParamSpec("P")
+R = TypeVar("R")
+S = TypeVar("S")
+
class _TypedFailure(Generic[_ExcType], Protocol):
"""Extension to twisted.Failure, where the 'value' has a certain type."""
@@ -98,7 +102,7 @@ class _TypedFailure(Generic[_ExcType], Protocol):
...
-def around(target):
+def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]:
"""A CLOS-style 'around' modifier, which wraps the original method of the
given instance with another piece of code.
@@ -107,11 +111,11 @@ def around(target):
return orig(*args, **kwargs)
"""
- def _around(code):
+ def _around(code: Callable[Concatenate[S, P], R]) -> None:
name = code.__name__
orig = getattr(target, name)
- def new(*args, **kwargs):
+ def new(*args: P.args, **kwargs: P.kwargs) -> R:
return code(orig, *args, **kwargs)
setattr(target, name, new)
@@ -132,7 +136,7 @@ class TestCase(unittest.TestCase):
level = getattr(method, "loglevel", getattr(self, "loglevel", None))
@around(self)
- def setUp(orig):
+ def setUp(orig: Callable[[], R]) -> R:
# if we're not starting in the sentinel logcontext, then to be honest
# all future bets are off.
if current_context():
@@ -145,7 +149,7 @@ class TestCase(unittest.TestCase):
if level is not None and old_level != level:
@around(self)
- def tearDown(orig):
+ def tearDown(orig: Callable[[], R]) -> R:
ret = orig()
logging.getLogger().setLevel(old_level)
return ret
@@ -159,7 +163,7 @@ class TestCase(unittest.TestCase):
return orig()
@around(self)
- def tearDown(orig):
+ def tearDown(orig: Callable[[], R]) -> R:
ret = orig()
# force a GC to workaround problems with deferreds leaking logcontexts when
# they are GCed (see the logcontext docs)
@@ -168,7 +172,7 @@ class TestCase(unittest.TestCase):
return ret
- def assertObjectHasAttributes(self, attrs, obj):
+ def assertObjectHasAttributes(self, attrs: Dict[str, object], obj: object) -> None:
"""Asserts that the given object has each of the attributes given, and
that the value of each matches according to assertEqual."""
for key in attrs.keys():
@@ -179,12 +183,12 @@ class TestCase(unittest.TestCase):
except AssertionError as e:
raise (type(e))(f"Assert error for '.{key}':") from e
- def assert_dict(self, required, actual):
+ def assert_dict(self, required: dict, actual: dict) -> None:
"""Does a partial assert of a dict.
Args:
- required (dict): The keys and value which MUST be in 'actual'.
- actual (dict): The test result. Extra keys will not be checked.
+ required: The keys and value which MUST be in 'actual'.
+ actual: The test result. Extra keys will not be checked.
"""
for key in required:
self.assertEqual(
@@ -192,31 +196,31 @@ class TestCase(unittest.TestCase):
)
-def DEBUG(target):
+def DEBUG(target: TV) -> TV:
"""A decorator to set the .loglevel attribute to logging.DEBUG.
Can apply to either a TestCase or an individual test method."""
- target.loglevel = logging.DEBUG
+ target.loglevel = logging.DEBUG # type: ignore[attr-defined]
return target
-def INFO(target):
+def INFO(target: TV) -> TV:
"""A decorator to set the .loglevel attribute to logging.INFO.
Can apply to either a TestCase or an individual test method."""
- target.loglevel = logging.INFO
+ target.loglevel = logging.INFO # type: ignore[attr-defined]
return target
-def logcontext_clean(target):
+def logcontext_clean(target: TV) -> TV:
"""A decorator which marks the TestCase or method as 'logcontext_clean'
... ie, any logcontext errors should cause a test failure
"""
- def logcontext_error(msg):
+ def logcontext_error(msg: str) -> NoReturn:
raise AssertionError("logcontext error: %s" % (msg))
patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error)
- return patcher(target)
+ return patcher(target) # type: ignore[call-overload]
class HomeserverTestCase(TestCase):
@@ -256,7 +260,7 @@ class HomeserverTestCase(TestCase):
method = getattr(self, methodName)
self._extra_config = getattr(method, "_extra_config", None)
- def setUp(self):
+ def setUp(self) -> None:
"""
Set up the TestCase by calling the homeserver constructor, optionally
hijacking the authentication system to return a fixed user, and then
@@ -285,7 +289,7 @@ class HomeserverTestCase(TestCase):
config=self.hs.config.server.listeners[0],
resource=self.resource,
server_version_string="1",
- max_request_body_size=1234,
+ max_request_body_size=4096,
reactor=self.reactor,
)
@@ -307,7 +311,9 @@ class HomeserverTestCase(TestCase):
)
)
- async def get_user_by_access_token(token=None, allow_guest=False):
+ async def get_user_by_access_token(
+ token: Optional[str] = None, allow_guest: bool = False
+ ) -> JsonDict:
assert self.helper.auth_user_id is not None
return {
"user": UserID.from_string(self.helper.auth_user_id),
@@ -315,7 +321,11 @@ class HomeserverTestCase(TestCase):
"is_guest": False,
}
- async def get_user_by_req(request, allow_guest=False, rights="access"):
+ async def get_user_by_req(
+ request: SynapseRequest,
+ allow_guest: bool = False,
+ allow_expired: bool = False,
+ ) -> Requester:
assert self.helper.auth_user_id is not None
return create_requester(
UserID.from_string(self.helper.auth_user_id),
@@ -340,11 +350,11 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "prepare"):
self.prepare(self.reactor, self.clock, self.hs)
- def tearDown(self):
+ def tearDown(self) -> None:
# Reset to not use frozen dicts.
events.USE_FROZEN_DICTS = False
- def wait_on_thread(self, deferred, timeout=10):
+ def wait_on_thread(self, deferred: Deferred, timeout: int = 10) -> None:
"""
Wait until a Deferred is done, where it's waiting on a real thread.
"""
@@ -375,7 +385,7 @@ class HomeserverTestCase(TestCase):
clock (synapse.util.Clock): The Clock, associated with the reactor.
Returns:
- A homeserver (synapse.server.HomeServer) suitable for testing.
+ A homeserver suitable for testing.
Function to be overridden in subclasses.
"""
@@ -409,7 +419,7 @@ class HomeserverTestCase(TestCase):
"/_synapse/admin": servlet_resource,
}
- def default_config(self):
+ def default_config(self) -> JsonDict:
"""
Get a default HomeServer config dict.
"""
@@ -422,7 +432,9 @@ class HomeserverTestCase(TestCase):
return config
- def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
"""
Prepare for the test. This involves things like mocking out parts of
the homeserver, or building test data common across the whole test
@@ -520,7 +532,7 @@ class HomeserverTestCase(TestCase):
config_obj.parse_config_dict(config, "", "")
kwargs["config"] = config_obj
- async def run_bg_updates():
+ async def run_bg_updates() -> None:
with LoggingContext("run_bg_updates"):
self.get_success(stor.db_pool.updates.run_background_updates(False))
@@ -539,11 +551,7 @@ class HomeserverTestCase(TestCase):
"""
self.reactor.pump([by] * 100)
- def get_success(
- self,
- d: Awaitable[TV],
- by: float = 0.0,
- ) -> TV:
+ def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV:
deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
self.pump(by=by)
return self.successResultOf(deferred)
@@ -619,20 +627,16 @@ class HomeserverTestCase(TestCase):
want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
want_mac_digest = want_mac.hexdigest()
- body = json.dumps(
- {
- "nonce": nonce,
- "username": username,
- "displayname": displayname,
- "password": password,
- "admin": admin,
- "mac": want_mac_digest,
- "inhibit_login": True,
- }
- )
- channel = self.make_request(
- "POST", "/_synapse/admin/v1/register", body.encode("utf8")
- )
+ body = {
+ "nonce": nonce,
+ "username": username,
+ "displayname": displayname,
+ "password": password,
+ "admin": admin,
+ "mac": want_mac_digest,
+ "inhibit_login": True,
+ }
+ channel = self.make_request("POST", "/_synapse/admin/v1/register", body)
self.assertEqual(channel.code, 200, channel.json_body)
user_id = channel.json_body["user_id"]
@@ -673,21 +677,34 @@ class HomeserverTestCase(TestCase):
username: str,
password: str,
device_id: Optional[str] = None,
+ additional_request_fields: Optional[Dict[str, str]] = None,
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
) -> str:
"""
- Log in a user, and get an access token. Requires the Login API be
- registered.
+ Log in a user, and get an access token. Requires the Login API be registered.
+ Args:
+ username: The localpart to assign to the new user.
+ password: The password to assign to the new user.
+ device_id: An optional device ID to assign to the new device created during
+ login.
+ additional_request_fields: A dictionary containing any additional /login
+ request fields and their values.
+ custom_headers: Custom HTTP headers and values to add to the /login request.
+
+ Returns:
+ The newly registered user's Matrix ID.
"""
body = {"type": "m.login.password", "user": username, "password": password}
if device_id:
body["device_id"] = device_id
+ if additional_request_fields:
+ body.update(additional_request_fields)
channel = self.make_request(
"POST",
"/_matrix/client/r0/login",
- json.dumps(body).encode("utf8"),
+ body,
custom_headers=custom_headers,
)
self.assertEqual(channel.code, 200, channel.result)
@@ -762,7 +779,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
OTHER_SERVER_NAME = "other.example.com"
OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test")
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
# poke the other server's signing key into the key store, so that we don't
@@ -780,7 +797,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
verify_key_id,
FetchKeyResult(
verify_key=verify_key,
- valid_until_ts=clock.time_msec() + 1000,
+ valid_until_ts=clock.time_msec() + 10000,
),
)
],
@@ -838,7 +855,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
client_ip=client_ip,
)
- def add_hashes_and_signatures(
+ def add_hashes_and_signatures_from_other_server(
self,
event_dict: JsonDict,
room_version: RoomVersion = KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
@@ -886,7 +903,7 @@ def _auth_header_for_request(
)
-def override_config(extra_config):
+def override_config(extra_config: JsonDict) -> Callable[[TV], TV]:
"""A decorator which can be applied to test functions to give additional HS config
For use
@@ -899,12 +916,13 @@ def override_config(extra_config):
...
Args:
- extra_config(dict): Additional config settings to be merged into the default
+ extra_config: Additional config settings to be merged into the default
config dict before instantiating the test homeserver.
"""
- def decorator(func):
- func._extra_config = extra_config
+ def decorator(func: TV) -> TV:
+ # This attribute is being defined.
+ func._extra_config = extra_config # type: ignore[attr-defined]
return func
return decorator
diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
index bee66dee43..e8b6246ab5 100644
--- a/tests/util/test_dict_cache.py
+++ b/tests/util/test_dict_cache.py
@@ -20,7 +20,7 @@ from tests import unittest
class DictCacheTestCase(unittest.TestCase):
def setUp(self):
- self.cache = DictionaryCache("foobar")
+ self.cache = DictionaryCache("foobar", max_entries=10)
def test_simple_cache_hit_full(self):
key = "test_simple_cache_hit_full"
@@ -76,13 +76,13 @@ class DictCacheTestCase(unittest.TestCase):
seq = self.cache.sequence
test_value_1 = {"test": "test_simple_cache_hit_miss_partial"}
- self.cache.update(seq, key, test_value_1, fetched_keys=set("test"))
+ self.cache.update(seq, key, test_value_1, fetched_keys={"test"})
seq = self.cache.sequence
test_value_2 = {"test2": "test_simple_cache_hit_miss_partial2"}
- self.cache.update(seq, key, test_value_2, fetched_keys=set("test2"))
+ self.cache.update(seq, key, test_value_2, fetched_keys={"test2"})
- c = self.cache.get(key)
+ c = self.cache.get(key, dict_keys=["test", "test2"])
self.assertEqual(
{
"test": "test_simple_cache_hit_miss_partial",
@@ -90,3 +90,30 @@ class DictCacheTestCase(unittest.TestCase):
},
c.value,
)
+ self.assertEqual(c.full, False)
+
+ def test_invalidation(self):
+ """Test that the partial dict and full dicts get invalidated
+ separately.
+ """
+ key = "some_key"
+
+ seq = self.cache.sequence
+ # start by populating a "full dict" entry
+ self.cache.update(seq, key, {"a": "b", "c": "d"})
+
+ # add a bunch of individual entries, also keeping the individual
+ # entry for "a" warm.
+ for i in range(20):
+ self.cache.get(key, ["a"])
+ self.cache.update(seq, f"key{i}", {1: 2})
+
+ # We should have evicted the full dict...
+ r = self.cache.get(key)
+ self.assertFalse(r.full)
+ self.assertTrue("c" not in r.value)
+
+ # ... but kept the "a" entry that we kept querying.
+ r = self.cache.get(key, dict_keys=["a"])
+ self.assertFalse(r.full)
+ self.assertEqual(r.value, {"a": "b"})
diff --git a/tests/util/test_macaroons.py b/tests/util/test_macaroons.py
new file mode 100644
index 0000000000..32125f7bb7
--- /dev/null
+++ b/tests/util/test_macaroons.py
@@ -0,0 +1,146 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pymacaroons.exceptions import MacaroonVerificationFailedException
+
+from synapse.util.macaroons import MacaroonGenerator, OidcSessionData
+
+from tests.server import get_clock
+from tests.unittest import TestCase
+
+
+class MacaroonGeneratorTestCase(TestCase):
+ def setUp(self):
+ self.reactor, hs_clock = get_clock()
+ self.macaroon_generator = MacaroonGenerator(hs_clock, "tesths", b"verysecret")
+ self.other_macaroon_generator = MacaroonGenerator(
+ hs_clock, "tesths", b"anothersecretkey"
+ )
+
+ def test_guest_access_token(self):
+ """Test the generation and verification of guest access tokens"""
+ token = self.macaroon_generator.generate_guest_access_token("@user:tesths")
+ user_id = self.macaroon_generator.verify_guest_token(token)
+ self.assertEqual(user_id, "@user:tesths")
+
+ # Raises with another secret key
+ with self.assertRaises(MacaroonVerificationFailedException):
+ self.other_macaroon_generator.verify_guest_token(token)
+
+ # Check that an old access token without the guest caveat does not work
+ macaroon = self.macaroon_generator._generate_base_macaroon("access")
+ macaroon.add_first_party_caveat(f"user_id = {user_id}")
+ macaroon.add_first_party_caveat("nonce = 0123456789abcdef")
+ token = macaroon.serialize()
+
+ with self.assertRaises(MacaroonVerificationFailedException):
+ self.macaroon_generator.verify_guest_token(token)
+
+ def test_delete_pusher_token(self):
+ """Test the generation and verification of delete_pusher tokens"""
+ token = self.macaroon_generator.generate_delete_pusher_token(
+ "@user:tesths", "m.mail", "john@example.com"
+ )
+ user_id = self.macaroon_generator.verify_delete_pusher_token(
+ token, "m.mail", "john@example.com"
+ )
+ self.assertEqual(user_id, "@user:tesths")
+
+ # Raises with another secret key
+ with self.assertRaises(MacaroonVerificationFailedException):
+ self.other_macaroon_generator.verify_delete_pusher_token(
+ token, "m.mail", "john@example.com"
+ )
+
+ # Raises when verifying for another pushkey
+ with self.assertRaises(MacaroonVerificationFailedException):
+ self.macaroon_generator.verify_delete_pusher_token(
+ token, "m.mail", "other@example.com"
+ )
+
+ # Raises when verifying for another app_id
+ with self.assertRaises(MacaroonVerificationFailedException):
+ self.macaroon_generator.verify_delete_pusher_token(
+ token, "somethingelse", "john@example.com"
+ )
+
+ # Check that an old token without the app_id and pushkey still works
+ macaroon = self.macaroon_generator._generate_base_macaroon("delete_pusher")
+ macaroon.add_first_party_caveat("user_id = @user:tesths")
+ token = macaroon.serialize()
+ user_id = self.macaroon_generator.verify_delete_pusher_token(
+ token, "m.mail", "john@example.com"
+ )
+ self.assertEqual(user_id, "@user:tesths")
+
+ def test_short_term_login_token(self):
+ """Test the generation and verification of short-term login tokens"""
+ token = self.macaroon_generator.generate_short_term_login_token(
+ user_id="@user:tesths",
+ auth_provider_id="oidc",
+ auth_provider_session_id="sid",
+ duration_in_ms=2 * 60 * 1000,
+ )
+
+ info = self.macaroon_generator.verify_short_term_login_token(token)
+ self.assertEqual(info.user_id, "@user:tesths")
+ self.assertEqual(info.auth_provider_id, "oidc")
+ self.assertEqual(info.auth_provider_session_id, "sid")
+
+ # Raises with another secret key
+ with self.assertRaises(MacaroonVerificationFailedException):
+ self.other_macaroon_generator.verify_short_term_login_token(token)
+
+ # Wait a minute
+ self.reactor.pump([60])
+ # Shouldn't raise
+ self.macaroon_generator.verify_short_term_login_token(token)
+ # Wait another minute
+ self.reactor.pump([60])
+ # Should raise since it expired
+ with self.assertRaises(MacaroonVerificationFailedException):
+ self.macaroon_generator.verify_short_term_login_token(token)
+
+ def test_oidc_session_token(self):
+ """Test the generation and verification of OIDC session cookies"""
+ state = "arandomstate"
+ session_data = OidcSessionData(
+ idp_id="oidc",
+ nonce="nonce",
+ client_redirect_url="https://example.com/",
+ ui_auth_session_id="",
+ )
+ token = self.macaroon_generator.generate_oidc_session_token(
+ state, session_data, duration_in_ms=2 * 60 * 1000
+ ).encode("utf-8")
+ info = self.macaroon_generator.verify_oidc_session_token(token, state)
+ self.assertEqual(session_data, info)
+
+ # Raises with another secret key
+ with self.assertRaises(MacaroonVerificationFailedException):
+ self.other_macaroon_generator.verify_oidc_session_token(token, state)
+
+ # Should raise with another state
+ with self.assertRaises(MacaroonVerificationFailedException):
+ self.macaroon_generator.verify_oidc_session_token(token, "anotherstate")
+
+ # Wait a minute
+ self.reactor.pump([60])
+ # Shouldn't raise
+ self.macaroon_generator.verify_oidc_session_token(token, state)
+ # Wait another minute
+ self.reactor.pump([60])
+ # Should raise since it expired
+ with self.assertRaises(MacaroonVerificationFailedException):
+ self.macaroon_generator.verify_oidc_session_token(token, state)
diff --git a/tests/utils.py b/tests/utils.py
index 3059c453d5..d2c6d1e852 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -15,12 +15,17 @@
import atexit
import os
+from typing import Any, Callable, Dict, List, Tuple, Union, overload
+
+import attr
+from typing_extensions import Literal, ParamSpec
from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.logging.context import current_context, set_current_context
+from synapse.server import HomeServer
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
@@ -50,12 +55,11 @@ SQLITE_PERSIST_DB = os.environ.get("SYNAPSE_TEST_PERSIST_SQLITE_DB") is not None
POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres"
-def setupdb():
+def setupdb() -> None:
# If we're using PostgreSQL, set up the db once
if USE_POSTGRES_FOR_TESTS:
# create a PostgresEngine
db_engine = create_engine({"name": "psycopg2", "args": {}})
-
# connect to postgres to create the base database.
db_conn = db_engine.module.connect(
user=POSTGRES_USER,
@@ -64,7 +68,7 @@ def setupdb():
password=POSTGRES_PASSWORD,
dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
)
- db_conn.autocommit = True
+ db_engine.attempt_to_set_autocommit(db_conn, autocommit=True)
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
cur.execute(
@@ -82,11 +86,11 @@ def setupdb():
port=POSTGRES_PORT,
password=POSTGRES_PASSWORD,
)
- db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
- prepare_database(db_conn, db_engine, None)
- db_conn.close()
+ logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
+ prepare_database(logging_conn, db_engine, None)
+ logging_conn.close()
- def _cleanup():
+ def _cleanup() -> None:
db_conn = db_engine.module.connect(
user=POSTGRES_USER,
host=POSTGRES_HOST,
@@ -94,7 +98,7 @@ def setupdb():
password=POSTGRES_PASSWORD,
dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
)
- db_conn.autocommit = True
+ db_engine.attempt_to_set_autocommit(db_conn, autocommit=True)
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
cur.close()
@@ -103,7 +107,19 @@ def setupdb():
atexit.register(_cleanup)
-def default_config(name, parse=False):
+@overload
+def default_config(name: str, parse: Literal[False] = ...) -> Dict[str, object]:
+ ...
+
+
+@overload
+def default_config(name: str, parse: Literal[True]) -> HomeServerConfig:
+ ...
+
+
+def default_config(
+ name: str, parse: bool = False
+) -> Union[Dict[str, object], HomeServerConfig]:
"""
Create a reasonable test config.
"""
@@ -151,6 +167,7 @@ def default_config(name, parse=False):
"local": {"per_second": 10000, "burst_count": 10000},
"remote": {"per_second": 10000, "burst_count": 10000},
},
+ "rc_joins_per_room": {"per_second": 10000, "burst_count": 10000},
"rc_invites": {
"per_room": {"per_second": 10000, "burst_count": 10000},
"per_user": {"per_second": 10000, "burst_count": 10000},
@@ -169,7 +186,7 @@ def default_config(name, parse=False):
# disable user directory updates, because they get done in the
# background, which upsets the test runner.
"update_user_directory": False,
- "caches": {"global_factor": 1},
+ "caches": {"global_factor": 1, "sync_response_cache_duration": 0},
"listeners": [{"port": 0, "type": "http"}],
}
@@ -181,90 +198,122 @@ def default_config(name, parse=False):
return config_dict
-def mock_getRawHeaders(headers=None):
+def mock_getRawHeaders(headers=None): # type: ignore[no-untyped-def]
headers = headers if headers is not None else {}
- def getRawHeaders(name, default=None):
+ def getRawHeaders(name, default=None): # type: ignore[no-untyped-def]
+ # If the requested header is present, the real twisted function returns
+ # List[str] if name is a str and List[bytes] if name is a bytes.
+ # This mock doesn't support that behaviour.
+ # Fortunately, none of the current callers of mock_getRawHeaders() provide a
+ # headers dict, so we don't encounter this discrepancy in practice.
return headers.get(name, default)
return getRawHeaders
+P = ParamSpec("P")
+
+
+@attr.s(slots=True, auto_attribs=True)
+class Timer:
+ absolute_time: float
+ callback: Callable[[], None]
+ expired: bool
+
+
+# TODO: Make this generic over a ParamSpec?
+@attr.s(slots=True, auto_attribs=True)
+class Looper:
+ func: Callable[..., Any]
+ interval: float # seconds
+ last: float
+ args: Tuple[object, ...]
+ kwargs: Dict[str, object]
+
+
class MockClock:
- now = 1000
+ now = 1000.0
- def __init__(self):
- # list of lists of [absolute_time, callback, expired] in no particular
- # order
- self.timers = []
- self.loopers = []
+ def __init__(self) -> None:
+ # Timers in no particular order
+ self.timers: List[Timer] = []
+ self.loopers: List[Looper] = []
- def time(self):
+ def time(self) -> float:
return self.now
- def time_msec(self):
- return self.time() * 1000
+ def time_msec(self) -> int:
+ return int(self.time() * 1000)
- def call_later(self, delay, callback, *args, **kwargs):
+ def call_later(
+ self,
+ delay: float,
+ callback: Callable[P, object],
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ) -> Timer:
ctx = current_context()
- def wrapped_callback():
+ def wrapped_callback() -> None:
set_current_context(ctx)
callback(*args, **kwargs)
- t = [self.now + delay, wrapped_callback, False]
+ t = Timer(self.now + delay, wrapped_callback, False)
self.timers.append(t)
return t
- def looping_call(self, function, interval, *args, **kwargs):
- self.loopers.append([function, interval / 1000.0, self.now, args, kwargs])
-
- def cancel_call_later(self, timer, ignore_errs=False):
- if timer[2]:
+ def looping_call(
+ self,
+ function: Callable[P, object],
+ interval: float,
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ) -> None:
+ # This type-ignore should be redundant once we use a mypy release with
+ # https://github.com/python/mypy/pull/12668.
+ self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs)) # type: ignore[arg-type]
+
+ def cancel_call_later(self, timer: Timer, ignore_errs: bool = False) -> None:
+ if timer.expired:
if not ignore_errs:
raise Exception("Cannot cancel an expired timer")
- timer[2] = True
+ timer.expired = True
self.timers = [t for t in self.timers if t != timer]
# For unit testing
- def advance_time(self, secs):
+ def advance_time(self, secs: float) -> None:
self.now += secs
timers = self.timers
self.timers = []
for t in timers:
- time, callback, expired = t
-
- if expired:
+ if t.expired:
raise Exception("Timer already expired")
- if self.now >= time:
- t[2] = True
- callback()
+ if self.now >= t.absolute_time:
+ t.expired = True
+ t.callback()
else:
self.timers.append(t)
for looped in self.loopers:
- func, interval, last, args, kwargs = looped
- if last + interval < self.now:
- func(*args, **kwargs)
- looped[2] = self.now
+ if looped.last + looped.interval < self.now:
+ looped.func(*looped.args, **looped.kwargs)
+ looped.last = self.now
- def advance_time_msec(self, ms):
+ def advance_time_msec(self, ms: float) -> None:
self.advance_time(ms / 1000.0)
- def time_bound_deferred(self, d, *args, **kwargs):
- # We don't bother timing things out for now.
- return d
-
-async def create_room(hs, room_id: str, creator_id: str):
+async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None:
"""Creates and persist a creation event for the given room"""
persistence_store = hs.get_storage_controllers().persistence
+ assert persistence_store is not None
store = hs.get_datastores().main
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()
|