diff --git a/tests/app/test_homeserver_start.py b/tests/app/test_homeserver_start.py
deleted file mode 100644
index cbcada0451..0000000000
--- a/tests/app/test_homeserver_start.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import synapse.app.homeserver
-from synapse.config._base import ConfigError
-
-from tests.config.utils import ConfigFileTestCase
-
-
-class HomeserverAppStartTestCase(ConfigFileTestCase):
- def test_wrong_start_caught(self):
- # Generate a config with a worker_app
- self.generate_config()
- # Add a blank line as otherwise the next addition ends up on a line with a comment
- self.add_lines_to_config([" "])
- self.add_lines_to_config(["worker_app: test_worker_app"])
-
- # Ensure that starting master process with worker config raises an exception
- with self.assertRaises(ConfigError):
- synapse.app.homeserver.setup(["-c", self.config_file])
diff --git a/tests/config/test_registration_config.py b/tests/config/test_registration_config.py
deleted file mode 100644
index 17a84d20d8..0000000000
--- a/tests/config/test_registration_config.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from synapse.config import ConfigError
-from synapse.config.homeserver import HomeServerConfig
-
-from tests.unittest import TestCase
-from tests.utils import default_config
-
-
-class RegistrationConfigTestCase(TestCase):
- def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self):
- """
- session_lifetime should logically be larger than, or at least as large as,
- all the different token lifetimes.
- Test that the user is faced with configuration errors if they make it
- smaller, as that configuration doesn't make sense.
- """
- config_dict = default_config("test")
-
- # First test all the error conditions
- with self.assertRaises(ConfigError):
- HomeServerConfig().parse_config_dict(
- {
- "session_lifetime": "30m",
- "nonrefreshable_access_token_lifetime": "31m",
- **config_dict,
- }
- )
-
- with self.assertRaises(ConfigError):
- HomeServerConfig().parse_config_dict(
- {
- "session_lifetime": "30m",
- "refreshable_access_token_lifetime": "31m",
- **config_dict,
- }
- )
-
- with self.assertRaises(ConfigError):
- HomeServerConfig().parse_config_dict(
- {
- "session_lifetime": "30m",
- "refresh_token_lifetime": "31m",
- **config_dict,
- }
- )
-
- # Then test all the fine conditions
- HomeServerConfig().parse_config_dict(
- {
- "session_lifetime": "31m",
- "nonrefreshable_access_token_lifetime": "31m",
- **config_dict,
- }
- )
-
- HomeServerConfig().parse_config_dict(
- {
- "session_lifetime": "31m",
- "refreshable_access_token_lifetime": "31m",
- **config_dict,
- }
- )
-
- HomeServerConfig().parse_config_dict(
- {"session_lifetime": "31m", "refresh_token_lifetime": "31m", **config_dict}
- )
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 17a9fb63a1..4d1e154578 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -22,7 +22,6 @@ import signedjson.sign
from nacl.signing import SigningKey
from signedjson.key import encode_verify_key_base64, get_verify_key
-from twisted.internet import defer
from twisted.internet.defer import Deferred, ensureDeferred
from synapse.api.errors import SynapseError
@@ -578,76 +577,6 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
- def test_get_multiple_keys_from_perspectives(self):
- """Check that we can correctly request multiple keys for the same server"""
-
- fetcher = PerspectivesKeyFetcher(self.hs)
-
- SERVER_NAME = "server2"
-
- testkey1 = signedjson.key.generate_signing_key("ver1")
- testverifykey1 = signedjson.key.get_verify_key(testkey1)
- testverifykey1_id = "ed25519:ver1"
-
- testkey2 = signedjson.key.generate_signing_key("ver2")
- testverifykey2 = signedjson.key.get_verify_key(testkey2)
- testverifykey2_id = "ed25519:ver2"
-
- VALID_UNTIL_TS = 200 * 1000
-
- response1 = self.build_perspectives_response(
- SERVER_NAME,
- testkey1,
- VALID_UNTIL_TS,
- )
- response2 = self.build_perspectives_response(
- SERVER_NAME,
- testkey2,
- VALID_UNTIL_TS,
- )
-
- async def post_json(destination, path, data, **kwargs):
- self.assertEqual(destination, self.mock_perspective_server.server_name)
- self.assertEqual(path, "/_matrix/key/v2/query")
-
- # check that the request is for the expected keys
- q = data["server_keys"]
-
- self.assertEqual(
- list(q[SERVER_NAME].keys()), [testverifykey1_id, testverifykey2_id]
- )
- return {"server_keys": [response1, response2]}
-
- self.http_client.post_json.side_effect = post_json
-
- # fire off two separate requests; they should get merged together into a
- # single HTTP hit.
- request1_d = defer.ensureDeferred(
- fetcher.get_keys(SERVER_NAME, [testverifykey1_id], 0)
- )
- request2_d = defer.ensureDeferred(
- fetcher.get_keys(SERVER_NAME, [testverifykey2_id], 0)
- )
-
- keys1 = self.get_success(request1_d)
- self.assertIn(testverifykey1_id, keys1)
- k = keys1[testverifykey1_id]
- self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
- self.assertEqual(k.verify_key, testverifykey1)
- self.assertEqual(k.verify_key.alg, "ed25519")
- self.assertEqual(k.verify_key.version, "ver1")
-
- keys2 = self.get_success(request2_d)
- self.assertIn(testverifykey2_id, keys2)
- k = keys2[testverifykey2_id]
- self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
- self.assertEqual(k.verify_key, testverifykey2)
- self.assertEqual(k.verify_key.alg, "ed25519")
- self.assertEqual(k.verify_key.version, "ver2")
-
- # finally, ensure that only one request was sent
- self.assertEqual(self.http_client.post_json.call_count, 1)
-
def test_get_perspectives_own_key(self):
"""Check that we can get the perspectives server's own keys
diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py
deleted file mode 100644
index a7031a55f2..0000000000
--- a/tests/federation/transport/test_client.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-
-from synapse.api.room_versions import RoomVersions
-from synapse.federation.transport.client import SendJoinParser
-
-from tests.unittest import TestCase
-
-
-class SendJoinParserTestCase(TestCase):
- def test_two_writes(self) -> None:
- """Test that the parser can sensibly deserialise an input given in two slices."""
- parser = SendJoinParser(RoomVersions.V1, True)
- parent_event = {
- "content": {
- "see_room_version_spec": "The event format changes depending on the room version."
- },
- "event_id": "$authparent",
- "room_id": "!somewhere:example.org",
- "type": "m.room.minimal_pdu",
- }
- state = {
- "content": {
- "see_room_version_spec": "The event format changes depending on the room version."
- },
- "event_id": "$DoNotThinkAboutTheEvent",
- "room_id": "!somewhere:example.org",
- "type": "m.room.minimal_pdu",
- }
- response = [
- 200,
- {
- "auth_chain": [parent_event],
- "origin": "matrix.org",
- "state": [state],
- },
- ]
- serialised_response = json.dumps(response).encode()
-
- # Send data to the parser
- parser.write(serialised_response[:100])
- parser.write(serialised_response[100:])
-
- # Retrieve the parsed SendJoinResponse
- parsed_response = parser.finish()
-
- # Sanity check the parsing gave us sensible data.
- self.assertEqual(len(parsed_response.auth_events), 1, parsed_response)
- self.assertEqual(len(parsed_response.state), 1, parsed_response)
- self.assertEqual(parsed_response.event_dict, {}, parsed_response)
- self.assertIsNone(parsed_response.event, parsed_response)
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 03b8b8615c..72e176da75 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -71,7 +71,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
def test_short_term_login_token_gives_user_id(self):
token = self.macaroon_generator.generate_short_term_login_token(
- self.user1, "", duration_in_ms=5000
+ self.user1, "", 5000
)
res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
self.assertEqual(self.user1, res.user_id)
@@ -94,7 +94,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
def test_short_term_login_token_cannot_replace_user_id(self):
token = self.macaroon_generator.generate_short_term_login_token(
- self.user1, "", duration_in_ms=5000
+ self.user1, "", 5000
)
macaroon = pymacaroons.Macaroon.deserialize(token)
@@ -213,6 +213,6 @@ class AuthTestCase(unittest.HomeserverTestCase):
def _get_macaroon(self):
token = self.macaroon_generator.generate_short_term_login_token(
- self.user1, "", duration_in_ms=5000
+ self.user1, "", 5000
)
return pymacaroons.Macaroon.deserialize(token)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 8705ff8943..b625995d12 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -66,13 +66,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "cas",
- request,
- "redirect_uri",
- None,
- new_user=True,
- auth_provider_session_id=None,
+ "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
)
def test_map_cas_user_to_existing_user(self):
@@ -95,13 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "cas",
- request,
- "redirect_uri",
- None,
- new_user=False,
- auth_provider_session_id=None,
+ "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
)
# Subsequent calls should map to the same mxid.
@@ -110,13 +98,7 @@ class CasHandlerTestCase(HomeserverTestCase):
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
)
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "cas",
- request,
- "redirect_uri",
- None,
- new_user=False,
- auth_provider_session_id=None,
+ "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
)
def test_map_cas_user_to_invalid_localpart(self):
@@ -134,13 +116,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@f=c3=b6=c3=b6:test",
- "cas",
- request,
- "redirect_uri",
- None,
- new_user=True,
- auth_provider_session_id=None,
+ "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True
)
@override_config(
@@ -184,13 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "cas",
- request,
- "redirect_uri",
- None,
- new_user=True,
- auth_provider_session_id=None,
+ "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index cfe3de5266..a25c89bd5b 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -252,6 +252,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
with patch.object(self.provider, "load_metadata", patched_load_metadata):
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
+ # Return empty key set if JWKS are not used
+ self.provider._scopes = [] # not asking the openid scope
+ self.http_client.get_json.reset_mock()
+ jwks = self.get_success(self.provider.load_jwks(force=True))
+ self.http_client.get_json.assert_not_called()
+ self.assertEqual(jwks, {"keys": []})
+
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_validate_config(self):
"""Provider metadatas are extensively validated."""
@@ -448,13 +455,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
- expected_user_id,
- "oidc",
- request,
- client_redirect_url,
- None,
- new_user=True,
- auth_provider_session_id=None,
+ expected_user_id, "oidc", request, client_redirect_url, None, new_user=True
)
self.provider._exchange_code.assert_called_once_with(code)
self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -481,58 +482,17 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.provider._fetch_userinfo.reset_mock()
# With userinfo fetching
- self.provider._user_profile_method = "userinfo_endpoint"
- token = {
- "type": "bearer",
- "access_token": "access_token",
- }
- self.provider._exchange_code = simple_async_mock(return_value=token)
+ self.provider._scopes = [] # do not ask the "openid" scope
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
- expected_user_id,
- "oidc",
- request,
- client_redirect_url,
- None,
- new_user=False,
- auth_provider_session_id=None,
+ expected_user_id, "oidc", request, client_redirect_url, None, new_user=False
)
self.provider._exchange_code.assert_called_once_with(code)
self.provider._parse_id_token.assert_not_called()
self.provider._fetch_userinfo.assert_called_once_with(token)
self.render_error.assert_not_called()
- # With an ID token, userinfo fetching and sid in the ID token
- self.provider._user_profile_method = "userinfo_endpoint"
- token = {
- "type": "bearer",
- "access_token": "access_token",
- "id_token": "id_token",
- }
- id_token = {
- "sid": "abcdefgh",
- }
- self.provider._parse_id_token = simple_async_mock(return_value=id_token)
- self.provider._exchange_code = simple_async_mock(return_value=token)
- auth_handler.complete_sso_login.reset_mock()
- self.provider._fetch_userinfo.reset_mock()
- self.get_success(self.handler.handle_oidc_callback(request))
-
- auth_handler.complete_sso_login.assert_called_once_with(
- expected_user_id,
- "oidc",
- request,
- client_redirect_url,
- None,
- new_user=False,
- auth_provider_session_id=id_token["sid"],
- )
- self.provider._exchange_code.assert_called_once_with(code)
- self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.provider._fetch_userinfo.assert_called_once_with(token)
- self.render_error.assert_not_called()
-
# Handle userinfo fetching error
self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
self.get_success(self.handler.handle_oidc_callback(request))
@@ -816,7 +776,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url,
{"phone": "1234567"},
new_user=True,
- auth_provider_session_id=None,
)
@override_config({"oidc_config": DEFAULT_CONFIG})
@@ -831,13 +790,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "oidc",
- ANY,
- ANY,
- None,
- new_user=True,
- auth_provider_session_id=None,
+ "@test_user:test", "oidc", ANY, ANY, None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -848,13 +801,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user_2:test",
- "oidc",
- ANY,
- ANY,
- None,
- new_user=True,
- auth_provider_session_id=None,
+ "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -891,26 +838,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- user.to_string(),
- "oidc",
- ANY,
- ANY,
- None,
- new_user=False,
- auth_provider_session_id=None,
+ user.to_string(), "oidc", ANY, ANY, None, new_user=False
)
auth_handler.complete_sso_login.reset_mock()
# Subsequent calls should map to the same mxid.
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- user.to_string(),
- "oidc",
- ANY,
- ANY,
- None,
- new_user=False,
- auth_provider_session_id=None,
+ user.to_string(), "oidc", ANY, ANY, None, new_user=False
)
auth_handler.complete_sso_login.reset_mock()
@@ -925,13 +860,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- user.to_string(),
- "oidc",
- ANY,
- ANY,
- None,
- new_user=False,
- auth_provider_session_id=None,
+ user.to_string(), "oidc", ANY, ANY, None, new_user=False
)
auth_handler.complete_sso_login.reset_mock()
@@ -967,13 +896,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- "@TEST_USER_2:test",
- "oidc",
- ANY,
- ANY,
- None,
- new_user=False,
- auth_provider_session_id=None,
+ "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
)
@override_config({"oidc_config": DEFAULT_CONFIG})
@@ -1011,13 +934,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
# test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user1:test",
- "oidc",
- ANY,
- ANY,
- None,
- new_user=True,
- auth_provider_session_id=None,
+ "@test_user1:test", "oidc", ANY, ANY, None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -1101,13 +1018,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@tester:test",
- "oidc",
- ANY,
- ANY,
- None,
- new_user=True,
- auth_provider_session_id=None,
+ "@tester:test", "oidc", ANY, ANY, None, new_user=True
)
@override_config(
@@ -1132,13 +1043,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@tester:test",
- "oidc",
- ANY,
- ANY,
- None,
- new_user=True,
- auth_provider_session_id=None,
+ "@tester:test", "oidc", ANY, ANY, None, new_user=True
)
@override_config(
@@ -1251,7 +1156,7 @@ async def _make_callback_with_userinfo(
handler = hs.get_oidc_handler()
provider = handler._providers["oidc"]
- provider._exchange_code = simple_async_mock(return_value={"id_token": ""})
+ provider._exchange_code = simple_async_mock(return_value={})
provider._parse_id_token = simple_async_mock(return_value=userinfo)
provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index e5a6a6c747..7b95844b55 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -32,7 +32,7 @@ from synapse.handlers.room_summary import _child_events_comparison_key, _RoomEnt
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID, create_requester
+from synapse.types import JsonDict, UserID
from tests import unittest
@@ -249,7 +249,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._assert_rooms(result, expected)
result = self.get_success(
- self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ self.handler.get_room_hierarchy(self.user, self.space)
)
self._assert_hierarchy(result, expected)
@@ -263,9 +263,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
expected = [(self.space, [self.room]), (self.room, ())]
self._assert_rooms(result, expected)
- result = self.get_success(
- self.handler.get_room_hierarchy(create_requester(user2), self.space)
- )
+ result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
self._assert_hierarchy(result, expected)
# If the space is made invite-only, it should no longer be viewable.
@@ -276,10 +274,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
tok=self.token,
)
self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError)
- self.get_failure(
- self.handler.get_room_hierarchy(create_requester(user2), self.space),
- AuthError,
- )
+ self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError)
# If the space is made world-readable it should return a result.
self.helper.send_state(
@@ -291,9 +286,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
result = self.get_success(self.handler.get_space_summary(user2, self.space))
self._assert_rooms(result, expected)
- result = self.get_success(
- self.handler.get_room_hierarchy(create_requester(user2), self.space)
- )
+ result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
self._assert_hierarchy(result, expected)
# Make it not world-readable again and confirm it results in an error.
@@ -304,10 +297,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
tok=self.token,
)
self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError)
- self.get_failure(
- self.handler.get_room_hierarchy(create_requester(user2), self.space),
- AuthError,
- )
+ self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError)
# Join the space and results should be returned.
self.helper.invite(self.space, targ=user2, tok=self.token)
@@ -315,9 +305,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
result = self.get_success(self.handler.get_space_summary(user2, self.space))
self._assert_rooms(result, expected)
- result = self.get_success(
- self.handler.get_room_hierarchy(create_requester(user2), self.space)
- )
+ result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
self._assert_hierarchy(result, expected)
# Attempting to view an unknown room returns the same error.
@@ -326,9 +314,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
AuthError,
)
self.get_failure(
- self.handler.get_room_hierarchy(
- create_requester(user2), "#not-a-space:" + self.hs.hostname
- ),
+ self.handler.get_room_hierarchy(user2, "#not-a-space:" + self.hs.hostname),
AuthError,
)
@@ -336,10 +322,10 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
"""In-flight room hierarchy requests are deduplicated."""
# Run two `get_room_hierarchy` calls up until they block.
deferred1 = ensureDeferred(
- self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ self.handler.get_room_hierarchy(self.user, self.space)
)
deferred2 = ensureDeferred(
- self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ self.handler.get_room_hierarchy(self.user, self.space)
)
# Complete the two calls.
@@ -354,7 +340,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
# A subsequent `get_room_hierarchy` call should not reuse the result.
result3 = self.get_success(
- self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ self.handler.get_room_hierarchy(self.user, self.space)
)
self._assert_hierarchy(result3, expected)
self.assertIsNot(result1, result3)
@@ -373,11 +359,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
# Run two `get_room_hierarchy` calls for different users up until they block.
deferred1 = ensureDeferred(
- self.handler.get_room_hierarchy(create_requester(self.user), self.space)
- )
- deferred2 = ensureDeferred(
- self.handler.get_room_hierarchy(create_requester(user2), self.space)
+ self.handler.get_room_hierarchy(self.user, self.space)
)
+ deferred2 = ensureDeferred(self.handler.get_room_hierarchy(user2, self.space))
# Complete the two calls.
result1 = self.get_success(deferred1)
@@ -481,9 +465,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
]
self._assert_rooms(result, expected)
- result = self.get_success(
- self.handler.get_room_hierarchy(create_requester(user2), self.space)
- )
+ result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
self._assert_hierarchy(result, expected)
def test_complex_space(self):
@@ -525,7 +507,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._assert_rooms(result, expected)
result = self.get_success(
- self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ self.handler.get_room_hierarchy(self.user, self.space)
)
self._assert_hierarchy(result, expected)
@@ -540,9 +522,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
room_ids.append(self.room)
result = self.get_success(
- self.handler.get_room_hierarchy(
- create_requester(self.user), self.space, limit=7
- )
+ self.handler.get_room_hierarchy(self.user, self.space, limit=7)
)
# The result should have the space and all of the links, plus some of the
# rooms and a pagination token.
@@ -554,10 +534,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
# Check the next page.
result = self.get_success(
self.handler.get_room_hierarchy(
- create_requester(self.user),
- self.space,
- limit=5,
- from_token=result["next_batch"],
+ self.user, self.space, limit=5, from_token=result["next_batch"]
)
)
# The result should have the space and the room in it, along with a link
@@ -577,22 +554,20 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
room_ids.append(self.room)
result = self.get_success(
- self.handler.get_room_hierarchy(
- create_requester(self.user), self.space, limit=7
- )
+ self.handler.get_room_hierarchy(self.user, self.space, limit=7)
)
self.assertIn("next_batch", result)
# Changing the room ID, suggested-only, or max-depth causes an error.
self.get_failure(
self.handler.get_room_hierarchy(
- create_requester(self.user), self.room, from_token=result["next_batch"]
+ self.user, self.room, from_token=result["next_batch"]
),
SynapseError,
)
self.get_failure(
self.handler.get_room_hierarchy(
- create_requester(self.user),
+ self.user,
self.space,
suggested_only=True,
from_token=result["next_batch"],
@@ -601,19 +576,14 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self.get_failure(
self.handler.get_room_hierarchy(
- create_requester(self.user),
- self.space,
- max_depth=0,
- from_token=result["next_batch"],
+ self.user, self.space, max_depth=0, from_token=result["next_batch"]
),
SynapseError,
)
# An invalid token is ignored.
self.get_failure(
- self.handler.get_room_hierarchy(
- create_requester(self.user), self.space, from_token="foo"
- ),
+ self.handler.get_room_hierarchy(self.user, self.space, from_token="foo"),
SynapseError,
)
@@ -639,18 +609,14 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
# Test just the space itself.
result = self.get_success(
- self.handler.get_room_hierarchy(
- create_requester(self.user), self.space, max_depth=0
- )
+ self.handler.get_room_hierarchy(self.user, self.space, max_depth=0)
)
expected: List[Tuple[str, Iterable[str]]] = [(spaces[0], [rooms[0], spaces[1]])]
self._assert_hierarchy(result, expected)
# A single additional layer.
result = self.get_success(
- self.handler.get_room_hierarchy(
- create_requester(self.user), self.space, max_depth=1
- )
+ self.handler.get_room_hierarchy(self.user, self.space, max_depth=1)
)
expected += [
(rooms[0], ()),
@@ -660,9 +626,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
# A few layers.
result = self.get_success(
- self.handler.get_room_hierarchy(
- create_requester(self.user), self.space, max_depth=3
- )
+ self.handler.get_room_hierarchy(self.user, self.space, max_depth=3)
)
expected += [
(rooms[1], ()),
@@ -693,7 +657,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._assert_rooms(result, expected)
result = self.get_success(
- self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ self.handler.get_room_hierarchy(self.user, self.space)
)
self._assert_hierarchy(result, expected)
@@ -775,7 +739,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
new=summarize_remote_room_hierarchy,
):
result = self.get_success(
- self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ self.handler.get_room_hierarchy(self.user, self.space)
)
self._assert_hierarchy(result, expected)
@@ -942,7 +906,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
new=summarize_remote_room_hierarchy,
):
result = self.get_success(
- self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ self.handler.get_room_hierarchy(self.user, self.space)
)
self._assert_hierarchy(result, expected)
@@ -1000,7 +964,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
new=summarize_remote_room_hierarchy,
):
result = self.get_success(
- self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ self.handler.get_room_hierarchy(self.user, self.space)
)
self._assert_hierarchy(result, expected)
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 50551aa6e3..8cfc184fef 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -130,13 +130,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "saml",
- request,
- "redirect_uri",
- None,
- new_user=True,
- auth_provider_session_id=None,
+ "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
)
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
@@ -162,13 +156,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "saml",
- request,
- "",
- None,
- new_user=False,
- auth_provider_session_id=None,
+ "@test_user:test", "saml", request, "", None, new_user=False
)
# Subsequent calls should map to the same mxid.
@@ -177,13 +165,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
self.handler._handle_authn_response(request, saml_response, "")
)
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "saml",
- request,
- "",
- None,
- new_user=False,
- auth_provider_session_id=None,
+ "@test_user:test", "saml", request, "", None, new_user=False
)
def test_map_saml_response_to_invalid_localpart(self):
@@ -231,13 +213,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user1:test",
- "saml",
- request,
- "",
- None,
- new_user=True,
- auth_provider_session_id=None,
+ "@test_user1:test", "saml", request, "", None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -333,13 +309,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "saml",
- request,
- "redirect_uri",
- None,
- new_user=True,
- auth_provider_session_id=None,
+ "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
)
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index f8cba7b645..90f800e564 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -128,7 +128,6 @@ class EmailPusherTests(HomeserverTestCase):
)
self.auth_handler = hs.get_auth_handler()
- self.store = hs.get_datastore()
def test_need_validated_email(self):
"""Test that we can only add an email pusher if the user has validated
@@ -409,7 +408,13 @@ class EmailPusherTests(HomeserverTestCase):
self.hs.get_datastore().db_pool.updates._all_done = False
# Now let's actually drive the updates to completion
- self.wait_for_background_updates()
+ while not self.get_success(
+ self.hs.get_datastore().db_pool.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.hs.get_datastore().db_pool.updates.do_next_background_update(100),
+ by=0.1,
+ )
# Check that all pushers with unlinked addresses were deleted
pushers = self.get_success(
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 596ba5a0c9..0a6e4795ee 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -17,7 +17,6 @@ from unittest.mock import patch
from synapse.api.room_versions import RoomVersion
from synapse.rest import admin
from synapse.rest.client import login, room, sync
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
@@ -194,10 +193,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
#
# Worker2's event stream position will not advance until we call
# __aexit__ again.
- worker_store2 = worker_hs2.get_datastore()
- assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator)
-
- actx = worker_store2._stream_id_gen.get_next()
+ actx = worker_hs2.get_datastore()._stream_id_gen.get_next()
self.get_success(actx.__aenter__())
response = self.helper.send(room_id1, body="Hi!", tok=self.other_access_token)
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 3adadcb46b..af849bd471 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
import os
import urllib.parse
-from http import HTTPStatus
from unittest.mock import Mock
from twisted.internet.defer import Deferred
@@ -41,7 +41,7 @@ class VersionTestCase(unittest.HomeserverTestCase):
def test_version_string(self):
channel = self.make_request("GET", self.url, shorthand=False)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
{"server_version", "python_version"}, set(channel.json_body.keys())
)
@@ -70,11 +70,11 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
content={"localpart": "test"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
group_id = channel.json_body["group_id"]
- self._check_group(group_id, expect_code=HTTPStatus.OK)
+ self._check_group(group_id, expect_code=200)
# Invite/join another user
@@ -82,13 +82,13 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={}
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
url = "/groups/%s/self/accept_invite" % (group_id,)
channel = self.make_request(
"PUT", url.encode("ascii"), access_token=self.other_user_token, content={}
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Check other user knows they're in the group
self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
@@ -103,10 +103,10 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
content={"localpart": "test"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- # Check group returns HTTPStatus.NOT_FOUND
- self._check_group(group_id, expect_code=HTTPStatus.NOT_FOUND)
+ # Check group returns 404
+ self._check_group(group_id, expect_code=404)
# Check users don't think they're in the group
self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
@@ -122,13 +122,15 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
- self.assertEqual(expect_code, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
def _get_groups_user_is_in(self, access_token):
"""Returns the list of groups the user is in (given their access token)"""
channel = self.make_request("GET", b"/joined_groups", access_token=access_token)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
return channel.json_body["groups"]
@@ -208,10 +210,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Should be quarantined
self.assertEqual(
- HTTPStatus.NOT_FOUND,
- channel.code,
+ 404,
+ int(channel.code),
msg=(
- "Expected to receive a HTTPStatus.NOT_FOUND on accessing quarantined media: %s"
+ "Expected to receive a 404 on accessing quarantined media: %s"
% server_and_media_id
),
)
@@ -230,8 +232,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Expect a forbidden error
self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
+ 403,
+ int(channel.result["code"]),
msg="Expected forbidden on quarantining media as a non-admin",
)
@@ -245,8 +247,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Expect a forbidden error
self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
+ 403,
+ int(channel.result["code"]),
msg="Expected forbidden on quarantining media as a non-admin",
)
@@ -277,7 +279,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
)
# Should be successful
- self.assertEqual(HTTPStatus.OK, channel.code)
+ self.assertEqual(200, int(channel.code), msg=channel.result["body"])
# Quarantine the media
url = "/_synapse/admin/v1/media/quarantine/%s/%s" % (
@@ -290,7 +292,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok,
)
self.pump(1.0)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.code), msg=channel.result["body"])
# Attempt to access the media
self._ensure_quarantined(admin_user_tok, server_name_and_media_id)
@@ -346,9 +348,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok,
)
self.pump(1.0)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.code), msg=channel.result["body"])
self.assertEqual(
- channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items"
+ json.loads(channel.result["body"].decode("utf-8")),
+ {"num_quarantined": 2},
+ "Expected 2 quarantined items",
)
# Convert mxc URLs to server/media_id strings
@@ -392,9 +396,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok,
)
self.pump(1.0)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
- channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items"
+ json.loads(channel.result["body"].decode("utf-8")),
+ {"num_quarantined": 2},
+ "Expected 2 quarantined items",
)
# Attempt to access each piece of media
@@ -426,7 +432,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),)
channel = self.make_request("POST", url, access_token=admin_user_tok)
self.pump(1.0)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.code), msg=channel.result["body"])
# Quarantine all media by this user
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
@@ -438,9 +444,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok,
)
self.pump(1.0)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
- channel.json_body, {"num_quarantined": 1}, "Expected 1 quarantined item"
+ json.loads(channel.result["body"].decode("utf-8")),
+ {"num_quarantined": 1},
+ "Expected 1 quarantined item",
)
# Attempt to access each piece of media, the first should fail, the
@@ -459,10 +467,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Shouldn't be quarantined
self.assertEqual(
- HTTPStatus.OK,
- channel.code,
+ 200,
+ int(channel.code),
msg=(
- "Expected to receive a HTTPStatus.OK on accessing not-quarantined media: %s"
+ "Expected to receive a 200 on accessing not-quarantined media: %s"
% server_and_media_id_2
),
)
@@ -491,7 +499,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
def test_purge_history(self):
"""
Simple test of purge history API.
- Test only that is is possible to call, get status HTTPStatus.OK and purge_id.
+ Test only that is is possible to call, get status 200 and purge_id.
"""
channel = self.make_request(
@@ -501,7 +509,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("purge_id", channel.json_body)
purge_id = channel.json_body["purge_id"]
@@ -512,5 +520,5 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("complete", channel.json_body["status"])
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index 4d152c0d66..cd5c60b65c 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -16,14 +16,11 @@ from typing import Collection
from parameterized import parameterized
-from twisted.test.proto_helpers import MemoryReactor
-
import synapse.rest.admin
from synapse.api.errors import Codes
from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.storage.background_updates import BackgroundUpdater
-from synapse.util import Clock
from tests import unittest
@@ -34,7 +31,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ def prepare(self, reactor, clock, hs: HomeServer):
self.store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -47,9 +44,9 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
("POST", "/_synapse/admin/v1/background_updates/start_job"),
]
)
- def test_requester_is_no_admin(self, method: str, url: str) -> None:
+ def test_requester_is_no_admin(self, method: str, url: str):
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
self.register_user("user", "pass", admin=False)
@@ -65,7 +62,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_invalid_parameter(self) -> None:
+ def test_invalid_parameter(self):
"""
If parameters are invalid, an error is returned.
"""
@@ -93,7 +90,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
- def _register_bg_update(self) -> None:
+ def _register_bg_update(self):
"Adds a bg update but doesn't start it"
async def _fake_update(progress, batch_size) -> int:
@@ -115,7 +112,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
)
)
- def test_status_empty(self) -> None:
+ def test_status_empty(self):
"""Test the status API works."""
channel = self.make_request(
@@ -130,7 +127,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
channel.json_body, {"current_updates": {}, "enabled": True}
)
- def test_status_bg_update(self) -> None:
+ def test_status_bg_update(self):
"""Test the status API works with a background update."""
# Create a new background update
@@ -138,7 +135,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
self._register_bg_update()
self.store.db_pool.updates.start_doing_background_updates()
- self.reactor.pump([1.0, 1.0, 1.0])
+ self.reactor.pump([1.0, 1.0])
channel = self.make_request(
"GET",
@@ -165,7 +162,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
},
)
- def test_enabled(self) -> None:
+ def test_enabled(self):
"""Test the enabled API works."""
# Create a new background update
@@ -302,7 +299,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
),
]
)
- def test_start_backround_job(self, job_name: str, updates: Collection[str]) -> None:
+ def test_start_backround_job(self, job_name: str, updates: Collection[str]):
"""
Test that background updates add to database and be processed.
@@ -344,7 +341,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
)
)
- def test_start_backround_job_twice(self) -> None:
+ def test_start_backround_job_twice(self):
"""Test that add a background update twice return an error."""
# add job to database
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index f7080bda87..a3679be205 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -11,18 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import urllib.parse
-from http import HTTPStatus
from parameterized import parameterized
-from twisted.test.proto_helpers import MemoryReactor
-
import synapse.rest.admin
from synapse.api.errors import Codes
from synapse.rest.client import login
-from synapse.server import HomeServer
-from synapse.util import Clock
from tests import unittest
@@ -34,7 +30,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ def prepare(self, reactor, clock, hs):
self.handler = hs.get_device_handler()
self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -51,21 +47,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
)
@parameterized.expand(["GET", "PUT", "DELETE"])
- def test_no_auth(self, method: str) -> None:
+ def test_no_auth(self, method: str):
"""
Try to get a device of an user without authentication.
"""
channel = self.make_request(method, self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "PUT", "DELETE"])
- def test_requester_is_no_admin(self, method: str) -> None:
+ def test_requester_is_no_admin(self, method: str):
"""
If the user is not a server admin, an error is returned.
"""
@@ -75,17 +67,13 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "PUT", "DELETE"])
- def test_user_does_not_exist(self, method: str) -> None:
+ def test_user_does_not_exist(self, method: str):
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = (
"/_synapse/admin/v2/users/@unknown_person:test/devices/%s"
@@ -98,13 +86,13 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@parameterized.expand(["GET", "PUT", "DELETE"])
- def test_user_is_not_local(self, method: str) -> None:
+ def test_user_is_not_local(self, method: str):
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = (
"/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices/%s"
@@ -117,12 +105,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
- def test_unknown_device(self) -> None:
+ def test_unknown_device(self):
"""
- Tests that a lookup for a device that does not exist returns either HTTPStatus.NOT_FOUND or HTTPStatus.OK.
+ Tests that a lookup for a device that does not exist returns either 404 or 200.
"""
url = "/_synapse/admin/v2/users/%s/devices/unknown_device" % urllib.parse.quote(
self.other_user
@@ -134,7 +122,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
channel = self.make_request(
@@ -143,7 +131,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
channel = self.make_request(
"DELETE",
@@ -151,10 +139,10 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- # Delete unknown device returns status HTTPStatus.OK
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ # Delete unknown device returns status 200
+ self.assertEqual(200, channel.code, msg=channel.json_body)
- def test_update_device_too_long_display_name(self) -> None:
+ def test_update_device_too_long_display_name(self):
"""
Update a device with a display name that is invalid (too long).
"""
@@ -179,7 +167,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
content=update,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
# Ensure the display name was not updated.
@@ -189,12 +177,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"])
- def test_update_no_display_name(self) -> None:
+ def test_update_no_display_name(self):
"""
- Tests that a update for a device without JSON returns a HTTPStatus.OK
+ Tests that a update for a device without JSON returns a 200
"""
# Set iniital display name.
update = {"display_name": "new display"}
@@ -210,7 +198,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Ensure the display name was not updated.
channel = self.make_request(
@@ -219,10 +207,10 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"])
- def test_update_display_name(self) -> None:
+ def test_update_display_name(self):
"""
Tests a normal successful update of display name
"""
@@ -234,7 +222,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
content={"display_name": "new displayname"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check new display_name
channel = self.make_request(
@@ -243,10 +231,10 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new displayname", channel.json_body["display_name"])
- def test_get_device(self) -> None:
+ def test_get_device(self):
"""
Tests that a normal lookup for a device is successfully
"""
@@ -256,7 +244,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
# Check that all fields are available
self.assertIn("user_id", channel.json_body)
@@ -265,7 +253,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertIn("last_seen_ip", channel.json_body)
self.assertIn("last_seen_ts", channel.json_body)
- def test_delete_device(self) -> None:
+ def test_delete_device(self):
"""
Tests that a remove of a device is successfully
"""
@@ -281,7 +269,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Ensure that the number of devices is decreased
res = self.get_success(self.handler.get_devices_by_user(self.other_user))
@@ -295,7 +283,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ def prepare(self, reactor, clock, hs):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -305,20 +293,16 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
self.other_user
)
- def test_no_auth(self) -> None:
+ def test_no_auth(self):
"""
Try to list devices of an user without authentication.
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_no_admin(self) -> None:
+ def test_requester_is_no_admin(self):
"""
If the user is not a server admin, an error is returned.
"""
@@ -330,16 +314,12 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_user_does_not_exist(self) -> None:
+ def test_user_does_not_exist(self):
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v2/users/@unknown_person:test/devices"
channel = self.make_request(
@@ -348,12 +328,12 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- def test_user_is_not_local(self) -> None:
+ def test_user_is_not_local(self):
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices"
@@ -363,10 +343,10 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
- def test_user_has_no_devices(self) -> None:
+ def test_user_has_no_devices(self):
"""
Tests that a normal lookup for devices is successfully
if user has no devices
@@ -379,11 +359,11 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["devices"]))
- def test_get_devices(self) -> None:
+ def test_get_devices(self):
"""
Tests that a normal lookup for devices is successfully
"""
@@ -399,7 +379,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_devices, channel.json_body["total"])
self.assertEqual(number_devices, len(channel.json_body["devices"]))
self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"])
@@ -419,7 +399,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ def prepare(self, reactor, clock, hs):
self.handler = hs.get_device_handler()
self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -431,20 +411,16 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
self.other_user
)
- def test_no_auth(self) -> None:
+ def test_no_auth(self):
"""
Try to delete devices of an user without authentication.
"""
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_no_admin(self) -> None:
+ def test_requester_is_no_admin(self):
"""
If the user is not a server admin, an error is returned.
"""
@@ -456,16 +432,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_user_does_not_exist(self) -> None:
+ def test_user_does_not_exist(self):
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices"
channel = self.make_request(
@@ -474,12 +446,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- def test_user_is_not_local(self) -> None:
+ def test_user_is_not_local(self):
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices"
@@ -489,12 +461,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
- def test_unknown_devices(self) -> None:
+ def test_unknown_devices(self):
"""
- Tests that a remove of a device that does not exist returns HTTPStatus.OK.
+ Tests that a remove of a device that does not exist returns 200.
"""
channel = self.make_request(
"POST",
@@ -503,10 +475,10 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
content={"devices": ["unknown_device1", "unknown_device2"]},
)
- # Delete unknown devices returns status HTTPStatus.OK
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ # Delete unknown devices returns status 200
+ self.assertEqual(200, channel.code, msg=channel.json_body)
- def test_delete_devices(self) -> None:
+ def test_delete_devices(self):
"""
Tests that a remove of devices is successfully
"""
@@ -533,7 +505,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
content={"devices": device_ids},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
res = self.get_success(self.handler.get_devices_by_user(self.other_user))
self.assertEqual(0, len(res))
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 4f89f8b534..e9ef89731f 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -11,17 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
-from typing import List
-from twisted.test.proto_helpers import MemoryReactor
+import json
import synapse.rest.admin
from synapse.api.errors import Codes
from synapse.rest.client import login, report_event, room
-from synapse.server import HomeServer
-from synapse.types import JsonDict
-from synapse.util import Clock
from tests import unittest
@@ -34,7 +29,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
report_event.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ def prepare(self, reactor, clock, hs):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -75,22 +70,18 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.url = "/_synapse/admin/v1/event_reports"
- def test_no_auth(self) -> None:
+ def test_no_auth(self):
"""
Try to get an event report without authentication.
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_no_admin(self) -> None:
+ def test_requester_is_no_admin(self):
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -99,14 +90,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_default_success(self) -> None:
+ def test_default_success(self):
"""
Testing list of reported events
"""
@@ -117,13 +104,13 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
self.assertNotIn("next_token", channel.json_body)
self._check_fields(channel.json_body["event_reports"])
- def test_limit(self) -> None:
+ def test_limit(self):
"""
Testing list of reported events with limit
"""
@@ -134,13 +121,13 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 5)
self.assertEqual(channel.json_body["next_token"], 5)
self._check_fields(channel.json_body["event_reports"])
- def test_from(self) -> None:
+ def test_from(self):
"""
Testing list of reported events with a defined starting point (from)
"""
@@ -151,13 +138,13 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 15)
self.assertNotIn("next_token", channel.json_body)
self._check_fields(channel.json_body["event_reports"])
- def test_limit_and_from(self) -> None:
+ def test_limit_and_from(self):
"""
Testing list of reported events with a defined starting point and limit
"""
@@ -168,13 +155,13 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(channel.json_body["next_token"], 15)
self.assertEqual(len(channel.json_body["event_reports"]), 10)
self._check_fields(channel.json_body["event_reports"])
- def test_filter_room(self) -> None:
+ def test_filter_room(self):
"""
Testing list of reported events with a filter of room
"""
@@ -185,7 +172,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["event_reports"]), 10)
self.assertNotIn("next_token", channel.json_body)
@@ -194,7 +181,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
for report in channel.json_body["event_reports"]:
self.assertEqual(report["room_id"], self.room_id1)
- def test_filter_user(self) -> None:
+ def test_filter_user(self):
"""
Testing list of reported events with a filter of user
"""
@@ -205,7 +192,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["event_reports"]), 10)
self.assertNotIn("next_token", channel.json_body)
@@ -214,7 +201,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
for report in channel.json_body["event_reports"]:
self.assertEqual(report["user_id"], self.other_user)
- def test_filter_user_and_room(self) -> None:
+ def test_filter_user_and_room(self):
"""
Testing list of reported events with a filter of user and room
"""
@@ -225,7 +212,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 5)
self.assertEqual(len(channel.json_body["event_reports"]), 5)
self.assertNotIn("next_token", channel.json_body)
@@ -235,7 +222,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertEqual(report["user_id"], self.other_user)
self.assertEqual(report["room_id"], self.room_id1)
- def test_valid_search_order(self) -> None:
+ def test_valid_search_order(self):
"""
Testing search order. Order by timestamps.
"""
@@ -247,7 +234,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
report = 1
@@ -265,7 +252,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
report = 1
@@ -276,9 +263,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
)
report += 1
- def test_invalid_search_order(self) -> None:
+ def test_invalid_search_order(self):
"""
- Testing that a invalid search order returns a HTTPStatus.BAD_REQUEST
+ Testing that a invalid search order returns a 400
"""
channel = self.make_request(
@@ -287,17 +274,13 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual("Unknown direction: bar", channel.json_body["error"])
- def test_limit_is_negative(self) -> None:
+ def test_limit_is_negative(self):
"""
- Testing that a negative limit parameter returns a HTTPStatus.BAD_REQUEST
+ Testing that a negative limit parameter returns a 400
"""
channel = self.make_request(
@@ -306,16 +289,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
- def test_from_is_negative(self) -> None:
+ def test_from_is_negative(self):
"""
- Testing that a negative from parameter returns a HTTPStatus.BAD_REQUEST
+ Testing that a negative from parameter returns a 400
"""
channel = self.make_request(
@@ -324,14 +303,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
- def test_next_token(self) -> None:
+ def test_next_token(self):
"""
Testing that `next_token` appears at the right place
"""
@@ -344,7 +319,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
self.assertNotIn("next_token", channel.json_body)
@@ -357,7 +332,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
self.assertNotIn("next_token", channel.json_body)
@@ -370,7 +345,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 19)
self.assertEqual(channel.json_body["next_token"], 19)
@@ -384,12 +359,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 1)
self.assertNotIn("next_token", channel.json_body)
- def _create_event_and_report(self, room_id: str, user_tok: str) -> None:
+ def _create_event_and_report(self, room_id, user_tok):
"""Create and report events"""
resp = self.helper.send(room_id, tok=user_tok)
event_id = resp["event_id"]
@@ -397,14 +372,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
"rooms/%s/report/%s" % (room_id, event_id),
- {"score": -100, "reason": "this makes me sad"},
+ json.dumps({"score": -100, "reason": "this makes me sad"}),
access_token=user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- def _create_event_and_report_without_parameters(
- self, room_id: str, user_tok: str
- ) -> None:
+ def _create_event_and_report_without_parameters(self, room_id, user_tok):
"""Create and report an event, but omit reason and score"""
resp = self.helper.send(room_id, tok=user_tok)
event_id = resp["event_id"]
@@ -412,12 +385,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
"rooms/%s/report/%s" % (room_id, event_id),
- {},
+ json.dumps({}),
access_token=user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- def _check_fields(self, content: List[JsonDict]) -> None:
+ def _check_fields(self, content):
"""Checks that all attributes are present in an event report"""
for c in content:
self.assertIn("id", c)
@@ -440,7 +413,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
report_event.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ def prepare(self, reactor, clock, hs):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -460,22 +433,18 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
# first created event report gets `id`=2
self.url = "/_synapse/admin/v1/event_reports/2"
- def test_no_auth(self) -> None:
+ def test_no_auth(self):
"""
Try to get event report without authentication.
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_no_admin(self) -> None:
+ def test_requester_is_no_admin(self):
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -484,14 +453,10 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_default_success(self) -> None:
+ def test_default_success(self):
"""
Testing get a reported event
"""
@@ -502,12 +467,12 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self._check_fields(channel.json_body)
- def test_invalid_report_id(self) -> None:
+ def test_invalid_report_id(self):
"""
- Testing that an invalid `report_id` returns a HTTPStatus.BAD_REQUEST.
+ Testing that an invalid `report_id` returns a 400.
"""
# `report_id` is negative
@@ -517,11 +482,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"The report_id parameter must be a string representing a positive integer.",
@@ -535,11 +496,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"The report_id parameter must be a string representing a positive integer.",
@@ -553,20 +510,16 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"The report_id parameter must be a string representing a positive integer.",
channel.json_body["error"],
)
- def test_report_id_not_found(self) -> None:
+ def test_report_id_not_found(self):
"""
- Testing that a not existing `report_id` returns a HTTPStatus.NOT_FOUND.
+ Testing that a not existing `report_id` returns a 404.
"""
channel = self.make_request(
@@ -575,15 +528,11 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.NOT_FOUND,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
self.assertEqual("Event report not found", channel.json_body["error"])
- def _create_event_and_report(self, room_id: str, user_tok: str) -> None:
+ def _create_event_and_report(self, room_id, user_tok):
"""Create and report events"""
resp = self.helper.send(room_id, tok=user_tok)
event_id = resp["event_id"]
@@ -591,12 +540,12 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
"rooms/%s/report/%s" % (room_id, event_id),
- {"score": -100, "reason": "this makes me sad"},
+ json.dumps({"score": -100, "reason": "this makes me sad"}),
access_token=user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- def _check_fields(self, content: JsonDict) -> None:
+ def _check_fields(self, content):
"""Checks that all attributes are present in a event report"""
self.assertIn("id", content)
self.assertIn("received_ts", content)
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
deleted file mode 100644
index 5188499ef2..0000000000
--- a/tests/rest/admin/test_federation.py
+++ /dev/null
@@ -1,456 +0,0 @@
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from http import HTTPStatus
-from typing import List, Optional
-
-from parameterized import parameterized
-
-import synapse.rest.admin
-from synapse.api.errors import Codes
-from synapse.rest.client import login
-from synapse.server import HomeServer
-from synapse.types import JsonDict
-
-from tests import unittest
-
-
-class FederationTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- ]
-
- def prepare(self, reactor, clock, hs: HomeServer):
- self.store = hs.get_datastore()
- self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- self.url = "/_synapse/admin/v1/federation/destinations"
-
- @parameterized.expand(
- [
- ("/_synapse/admin/v1/federation/destinations",),
- ("/_synapse/admin/v1/federation/destinations/dummy",),
- ]
- )
- def test_requester_is_no_admin(self, url: str):
- """
- If the user is not a server admin, an error 403 is returned.
- """
-
- self.register_user("user", "pass", admin=False)
- other_user_tok = self.login("user", "pass")
-
- channel = self.make_request(
- "GET",
- url,
- content={},
- access_token=other_user_tok,
- )
-
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- def test_invalid_parameter(self):
- """
- If parameters are invalid, an error is returned.
- """
-
- # negative limit
- channel = self.make_request(
- "GET",
- self.url + "?limit=-5",
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
-
- # negative from
- channel = self.make_request(
- "GET",
- self.url + "?from=-5",
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
-
- # unkown order_by
- channel = self.make_request(
- "GET",
- self.url + "?order_by=bar",
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
-
- # invalid search order
- channel = self.make_request(
- "GET",
- self.url + "?dir=bar",
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
-
- # invalid destination
- channel = self.make_request(
- "GET",
- self.url + "/dummy",
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- def test_limit(self):
- """
- Testing list of destinations with limit
- """
-
- number_destinations = 20
- self._create_destinations(number_destinations)
-
- channel = self.make_request(
- "GET",
- self.url + "?limit=5",
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["total"], number_destinations)
- self.assertEqual(len(channel.json_body["destinations"]), 5)
- self.assertEqual(channel.json_body["next_token"], "5")
- self._check_fields(channel.json_body["destinations"])
-
- def test_from(self):
- """
- Testing list of destinations with a defined starting point (from)
- """
-
- number_destinations = 20
- self._create_destinations(number_destinations)
-
- channel = self.make_request(
- "GET",
- self.url + "?from=5",
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["total"], number_destinations)
- self.assertEqual(len(channel.json_body["destinations"]), 15)
- self.assertNotIn("next_token", channel.json_body)
- self._check_fields(channel.json_body["destinations"])
-
- def test_limit_and_from(self):
- """
- Testing list of destinations with a defined starting point and limit
- """
-
- number_destinations = 20
- self._create_destinations(number_destinations)
-
- channel = self.make_request(
- "GET",
- self.url + "?from=5&limit=10",
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["total"], number_destinations)
- self.assertEqual(channel.json_body["next_token"], "15")
- self.assertEqual(len(channel.json_body["destinations"]), 10)
- self._check_fields(channel.json_body["destinations"])
-
- def test_next_token(self):
- """
- Testing that `next_token` appears at the right place
- """
-
- number_destinations = 20
- self._create_destinations(number_destinations)
-
- # `next_token` does not appear
- # Number of results is the number of entries
- channel = self.make_request(
- "GET",
- self.url + "?limit=20",
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["total"], number_destinations)
- self.assertEqual(len(channel.json_body["destinations"]), number_destinations)
- self.assertNotIn("next_token", channel.json_body)
-
- # `next_token` does not appear
- # Number of max results is larger than the number of entries
- channel = self.make_request(
- "GET",
- self.url + "?limit=21",
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["total"], number_destinations)
- self.assertEqual(len(channel.json_body["destinations"]), number_destinations)
- self.assertNotIn("next_token", channel.json_body)
-
- # `next_token` does appear
- # Number of max results is smaller than the number of entries
- channel = self.make_request(
- "GET",
- self.url + "?limit=19",
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["total"], number_destinations)
- self.assertEqual(len(channel.json_body["destinations"]), 19)
- self.assertEqual(channel.json_body["next_token"], "19")
-
- # Check
- # Set `from` to value of `next_token` for request remaining entries
- # `next_token` does not appear
- channel = self.make_request(
- "GET",
- self.url + "?from=19",
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["total"], number_destinations)
- self.assertEqual(len(channel.json_body["destinations"]), 1)
- self.assertNotIn("next_token", channel.json_body)
-
- def test_list_all_destinations(self):
- """
- List all destinations.
- """
- number_destinations = 5
- self._create_destinations(number_destinations)
-
- channel = self.make_request(
- "GET",
- self.url,
- {},
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
- self.assertEqual(number_destinations, len(channel.json_body["destinations"]))
- self.assertEqual(number_destinations, channel.json_body["total"])
-
- # Check that all fields are available
- self._check_fields(channel.json_body["destinations"])
-
- def test_order_by(self):
- """
- Testing order list with parameter `order_by`
- """
-
- def _order_test(
- expected_destination_list: List[str],
- order_by: Optional[str],
- dir: Optional[str] = None,
- ):
- """Request the list of destinations in a certain order.
- Assert that order is what we expect
-
- Args:
- expected_destination_list: The list of user_id in the order
- we expect to get back from the server
- order_by: The type of ordering to give the server
- dir: The direction of ordering to give the server
- """
-
- url = f"{self.url}?"
- if order_by is not None:
- url += f"order_by={order_by}&"
- if dir is not None and dir in ("b", "f"):
- url += f"dir={dir}"
- channel = self.make_request(
- "GET",
- url,
- access_token=self.admin_user_tok,
- )
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["total"], len(expected_destination_list))
-
- returned_order = [
- row["destination"] for row in channel.json_body["destinations"]
- ]
- self.assertEqual(expected_destination_list, returned_order)
- self._check_fields(channel.json_body["destinations"])
-
- # create destinations
- dest = [
- ("sub-a.example.com", 100, 300, 200, 300),
- ("sub-b.example.com", 200, 200, 100, 100),
- ("sub-c.example.com", 300, 100, 300, 200),
- ]
- for (
- destination,
- failure_ts,
- retry_last_ts,
- retry_interval,
- last_successful_stream_ordering,
- ) in dest:
- self.get_success(
- self.store.set_destination_retry_timings(
- destination, failure_ts, retry_last_ts, retry_interval
- )
- )
- self.get_success(
- self.store.set_destination_last_successful_stream_ordering(
- destination, last_successful_stream_ordering
- )
- )
-
- # order by default (destination)
- _order_test([dest[0][0], dest[1][0], dest[2][0]], None)
- _order_test([dest[0][0], dest[1][0], dest[2][0]], None, "f")
- _order_test([dest[2][0], dest[1][0], dest[0][0]], None, "b")
-
- # order by destination
- _order_test([dest[0][0], dest[1][0], dest[2][0]], "destination")
- _order_test([dest[0][0], dest[1][0], dest[2][0]], "destination", "f")
- _order_test([dest[2][0], dest[1][0], dest[0][0]], "destination", "b")
-
- # order by failure_ts
- _order_test([dest[0][0], dest[1][0], dest[2][0]], "failure_ts")
- _order_test([dest[0][0], dest[1][0], dest[2][0]], "failure_ts", "f")
- _order_test([dest[2][0], dest[1][0], dest[0][0]], "failure_ts", "b")
-
- # order by retry_last_ts
- _order_test([dest[2][0], dest[1][0], dest[0][0]], "retry_last_ts")
- _order_test([dest[2][0], dest[1][0], dest[0][0]], "retry_last_ts", "f")
- _order_test([dest[0][0], dest[1][0], dest[2][0]], "retry_last_ts", "b")
-
- # order by retry_interval
- _order_test([dest[1][0], dest[0][0], dest[2][0]], "retry_interval")
- _order_test([dest[1][0], dest[0][0], dest[2][0]], "retry_interval", "f")
- _order_test([dest[2][0], dest[0][0], dest[1][0]], "retry_interval", "b")
-
- # order by last_successful_stream_ordering
- _order_test(
- [dest[1][0], dest[2][0], dest[0][0]], "last_successful_stream_ordering"
- )
- _order_test(
- [dest[1][0], dest[2][0], dest[0][0]], "last_successful_stream_ordering", "f"
- )
- _order_test(
- [dest[0][0], dest[2][0], dest[1][0]], "last_successful_stream_ordering", "b"
- )
-
- def test_search_term(self):
- """Test that searching for a destination works correctly"""
-
- def _search_test(
- expected_destination: Optional[str],
- search_term: str,
- ):
- """Search for a destination and check that the returned destinationis a match
-
- Args:
- expected_destination: The room_id expected to be returned by the API.
- Set to None to expect zero results for the search
- search_term: The term to search for room names with
- """
- url = f"{self.url}?destination={search_term}"
- channel = self.make_request(
- "GET",
- url.encode("ascii"),
- access_token=self.admin_user_tok,
- )
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
-
- # Check that destinations were returned
- self.assertTrue("destinations" in channel.json_body)
- self._check_fields(channel.json_body["destinations"])
- destinations = channel.json_body["destinations"]
-
- # Check that the expected number of destinations were returned
- expected_destination_count = 1 if expected_destination else 0
- self.assertEqual(len(destinations), expected_destination_count)
- self.assertEqual(channel.json_body["total"], expected_destination_count)
-
- if expected_destination:
- # Check that the first returned destination is correct
- self.assertEqual(expected_destination, destinations[0]["destination"])
-
- number_destinations = 3
- self._create_destinations(number_destinations)
-
- # Test searching
- _search_test("sub0.example.com", "0")
- _search_test("sub0.example.com", "sub0")
-
- _search_test("sub1.example.com", "1")
- _search_test("sub1.example.com", "1.")
-
- # Test case insensitive
- _search_test("sub0.example.com", "SUB0")
-
- _search_test(None, "foo")
- _search_test(None, "bar")
-
- def test_get_single_destination(self):
- """
- Get one specific destinations.
- """
- self._create_destinations(5)
-
- channel = self.make_request(
- "GET",
- self.url + "/sub0.example.com",
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
- self.assertEqual("sub0.example.com", channel.json_body["destination"])
-
- # Check that all fields are available
- # convert channel.json_body into a List
- self._check_fields([channel.json_body])
-
- def _create_destinations(self, number_destinations: int):
- """Create a number of destinations
-
- Args:
- number_destinations: Number of destinations to be created
- """
- for i in range(0, number_destinations):
- dest = f"sub{i}.example.com"
- self.get_success(self.store.set_destination_retry_timings(dest, 50, 50, 50))
- self.get_success(
- self.store.set_destination_last_successful_stream_ordering(dest, 100)
- )
-
- def _check_fields(self, content: List[JsonDict]):
- """Checks that the expected destination attributes are present in content
-
- Args:
- content: List that is checked for content
- """
- for c in content:
- self.assertIn("destination", c)
- self.assertIn("retry_last_ts", c)
- self.assertIn("retry_interval", c)
- self.assertIn("failure_ts", c)
- self.assertIn("last_successful_stream_ordering", c)
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 81e578fd26..db0e78c039 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -12,19 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+import json
import os
-from http import HTTPStatus
from parameterized import parameterized
-from twisted.test.proto_helpers import MemoryReactor
-
import synapse.rest.admin
from synapse.api.errors import Codes
from synapse.rest.client import login, profile, room
from synapse.rest.media.v1.filepath import MediaFilePaths
-from synapse.server import HomeServer
-from synapse.util import Clock
from tests import unittest
from tests.server import FakeSite, make_request
@@ -42,7 +39,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ def prepare(self, reactor, clock, hs):
self.media_repo = hs.get_media_repository_resource()
self.server_name = hs.hostname
@@ -51,7 +48,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
- def test_no_auth(self) -> None:
+ def test_no_auth(self):
"""
Try to delete media without authentication.
"""
@@ -59,14 +56,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
channel = self.make_request("DELETE", url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_no_admin(self) -> None:
+ def test_requester_is_no_admin(self):
"""
If the user is not a server admin, an error is returned.
"""
@@ -81,16 +74,12 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_media_does_not_exist(self) -> None:
+ def test_media_does_not_exist(self):
"""
- Tests that a lookup for a media that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a media that does not exist returns a 404
"""
url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
@@ -100,12 +89,12 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- def test_media_is_not_local(self) -> None:
+ def test_media_is_not_local(self):
"""
- Tests that a lookup for a media that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a media that is not a local returns a 400
"""
url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345")
@@ -115,10 +104,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only delete local media", channel.json_body["error"])
- def test_delete_media(self) -> None:
+ def test_delete_media(self):
"""
Tests that delete a media is successfully
"""
@@ -128,10 +117,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource,
- SMALL_PNG,
- tok=self.admin_user_tok,
- expect_code=HTTPStatus.OK,
+ upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -151,11 +137,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
# Should be successful
self.assertEqual(
- HTTPStatus.OK,
+ 200,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.OK on accessing media: %s"
- % server_and_media_id
+ "Expected to receive a 200 on accessing media: %s" % server_and_media_id
),
)
@@ -172,7 +157,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
media_id,
@@ -189,10 +174,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
self.assertEqual(
- HTTPStatus.NOT_FOUND,
+ 404,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s"
+ "Expected to receive a 404 on accessing deleted media: %s"
% server_and_media_id
),
)
@@ -211,7 +196,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ def prepare(self, reactor, clock, hs):
self.media_repo = hs.get_media_repository_resource()
self.server_name = hs.hostname
@@ -224,21 +209,17 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
# Move clock up to somewhat realistic time
self.reactor.advance(1000000000)
- def test_no_auth(self) -> None:
+ def test_no_auth(self):
"""
Try to delete media without authentication.
"""
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_no_admin(self) -> None:
+ def test_requester_is_no_admin(self):
"""
If the user is not a server admin, an error is returned.
"""
@@ -251,16 +232,12 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_media_is_not_local(self) -> None:
+ def test_media_is_not_local(self):
"""
- Tests that a lookup for media that is not local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for media that is not local returns a 400
"""
url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain"
@@ -270,10 +247,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only delete local media", channel.json_body["error"])
- def test_missing_parameter(self) -> None:
+ def test_missing_parameter(self):
"""
If the parameter `before_ts` is missing, an error is returned.
"""
@@ -283,17 +260,13 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Missing integer query parameter 'before_ts'", channel.json_body["error"]
)
- def test_invalid_parameter(self) -> None:
+ def test_invalid_parameter(self):
"""
If parameters are invalid, an error is returned.
"""
@@ -303,11 +276,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts must be a positive integer.",
@@ -320,11 +289,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts you provided is from the year 1970. "
@@ -338,11 +303,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter size_gt must be a string representing a positive integer.",
@@ -355,18 +316,14 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual(
"Boolean query parameter 'keep_profiles' must be one of ['true', 'false']",
channel.json_body["error"],
)
- def test_delete_media_never_accessed(self) -> None:
+ def test_delete_media_never_accessed(self):
"""
Tests that media deleted if it is older than `before_ts` and never accessed
`last_access_ts` is `NULL` and `created_ts` < `before_ts`
@@ -388,7 +345,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
media_id,
@@ -397,7 +354,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id, False)
- def test_keep_media_by_date(self) -> None:
+ def test_keep_media_by_date(self):
"""
Tests that media is not deleted if it is newer than `before_ts`
"""
@@ -413,7 +370,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id)
@@ -425,7 +382,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
server_and_media_id.split("/")[1],
@@ -434,7 +391,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id, False)
- def test_keep_media_by_size(self) -> None:
+ def test_keep_media_by_size(self):
"""
Tests that media is not deleted if its size is smaller than or equal
to `size_gt`
@@ -449,7 +406,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&size_gt=67",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id)
@@ -460,7 +417,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&size_gt=66",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
server_and_media_id.split("/")[1],
@@ -469,7 +426,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id, False)
- def test_keep_media_by_user_avatar(self) -> None:
+ def test_keep_media_by_user_avatar(self):
"""
Tests that we do not delete media if is used as a user avatar
Tests parameter `keep_profiles`
@@ -482,10 +439,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
"/profile/%s/avatar_url" % (self.admin_user,),
- content={"avatar_url": "mxc://%s" % (server_and_media_id,)},
+ content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)}),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
now_ms = self.clock.time_msec()
channel = self.make_request(
@@ -493,7 +450,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id)
@@ -504,7 +461,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
server_and_media_id.split("/")[1],
@@ -513,7 +470,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id, False)
- def test_keep_media_by_room_avatar(self) -> None:
+ def test_keep_media_by_room_avatar(self):
"""
Tests that we do not delete media if it is used as a room avatar
Tests parameter `keep_profiles`
@@ -527,10 +484,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
"/rooms/%s/state/m.room.avatar" % (room_id,),
- content={"url": "mxc://%s" % (server_and_media_id,)},
+ content=json.dumps({"url": "mxc://%s" % (server_and_media_id,)}),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
now_ms = self.clock.time_msec()
channel = self.make_request(
@@ -538,7 +495,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id)
@@ -549,7 +506,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
server_and_media_id.split("/")[1],
@@ -558,7 +515,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id, False)
- def _create_media(self) -> str:
+ def _create_media(self):
"""
Create a media and return media_id and server_and_media_id
"""
@@ -566,10 +523,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource,
- SMALL_PNG,
- tok=self.admin_user_tok,
- expect_code=HTTPStatus.OK,
+ upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -580,7 +534,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
return server_and_media_id
- def _access_media(self, server_and_media_id, expect_success=True) -> None:
+ def _access_media(self, server_and_media_id, expect_success=True):
"""
Try to access a media and check the result
"""
@@ -600,10 +554,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
if expect_success:
self.assertEqual(
- HTTPStatus.OK,
+ 200,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.OK on accessing media: %s"
+ "Expected to receive a 200 on accessing media: %s"
% server_and_media_id
),
)
@@ -611,10 +565,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertTrue(os.path.exists(local_path))
else:
self.assertEqual(
- HTTPStatus.NOT_FOUND,
+ 404,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s"
+ "Expected to receive a 404 on accessing deleted media: %s"
% (server_and_media_id)
),
)
@@ -630,7 +584,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ def prepare(self, reactor, clock, hs):
media_repo = hs.get_media_repository_resource()
self.store = hs.get_datastore()
self.server_name = hs.hostname
@@ -643,10 +597,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource,
- SMALL_PNG,
- tok=self.admin_user_tok,
- expect_code=HTTPStatus.OK,
+ upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -655,7 +606,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
self.url = "/_synapse/admin/v1/media/%s/%s/%s"
@parameterized.expand(["quarantine", "unquarantine"])
- def test_no_auth(self, action: str) -> None:
+ def test_no_auth(self, action: str):
"""
Try to protect media without authentication.
"""
@@ -666,15 +617,11 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
b"{}",
)
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["quarantine", "unquarantine"])
- def test_requester_is_no_admin(self, action: str) -> None:
+ def test_requester_is_no_admin(self, action: str):
"""
If the user is not a server admin, an error is returned.
"""
@@ -687,14 +634,10 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_quarantine_media(self) -> None:
+ def test_quarantine_media(self):
"""
Tests that quarantining and remove from quarantine a media is successfully
"""
@@ -709,7 +652,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
@@ -722,13 +665,13 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
self.assertFalse(media_info["quarantined_by"])
- def test_quarantine_protected_media(self) -> None:
+ def test_quarantine_protected_media(self):
"""
Tests that quarantining from protected media fails
"""
@@ -747,7 +690,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
# verify that is not in quarantine
@@ -763,7 +706,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ def prepare(self, reactor, clock, hs):
media_repo = hs.get_media_repository_resource()
self.store = hs.get_datastore()
@@ -775,10 +718,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource,
- SMALL_PNG,
- tok=self.admin_user_tok,
- expect_code=HTTPStatus.OK,
+ upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -787,22 +727,18 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
self.url = "/_synapse/admin/v1/media/%s/%s"
@parameterized.expand(["protect", "unprotect"])
- def test_no_auth(self, action: str) -> None:
+ def test_no_auth(self, action: str):
"""
Try to protect media without authentication.
"""
channel = self.make_request("POST", self.url % (action, self.media_id), b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["protect", "unprotect"])
- def test_requester_is_no_admin(self, action: str) -> None:
+ def test_requester_is_no_admin(self, action: str):
"""
If the user is not a server admin, an error is returned.
"""
@@ -815,14 +751,10 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_protect_media(self) -> None:
+ def test_protect_media(self):
"""
Tests that protect and unprotect a media is successfully
"""
@@ -837,7 +769,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
@@ -850,7 +782,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
@@ -867,7 +799,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ def prepare(self, reactor, clock, hs):
self.media_repo = hs.get_media_repository_resource()
self.server_name = hs.hostname
@@ -877,21 +809,17 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
self.url = "/_synapse/admin/v1/purge_media_cache"
- def test_no_auth(self) -> None:
+ def test_no_auth(self):
"""
Try to delete media without authentication.
"""
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_not_admin(self) -> None:
+ def test_requester_is_not_admin(self):
"""
If the user is not a server admin, an error is returned.
"""
@@ -904,14 +832,10 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_invalid_parameter(self) -> None:
+ def test_invalid_parameter(self):
"""
If parameters are invalid, an error is returned.
"""
@@ -921,11 +845,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts must be a positive integer.",
@@ -938,11 +858,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts you provided is from the year 1970. "
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
index 350a62dda6..9bac423ae0 100644
--- a/tests/rest/admin/test_registration_tokens.py
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -11,17 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import random
import string
-from http import HTTPStatus
-
-from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.errors import Codes
from synapse.rest.client import login
-from synapse.server import HomeServer
-from synapse.util import Clock
from tests import unittest
@@ -32,7 +28,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -42,7 +38,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.url = "/_synapse/admin/v1/registration_tokens"
- def _new_token(self, **kwargs) -> str:
+ def _new_token(self, **kwargs):
"""Helper function to create a token."""
token = kwargs.get(
"token",
@@ -64,17 +60,13 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
# CREATION
- def test_create_no_auth(self) -> None:
+ def test_create_no_auth(self):
"""Try to create a token without authentication."""
channel = self.make_request("POST", self.url + "/new", {})
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_create_requester_not_admin(self) -> None:
+ def test_create_requester_not_admin(self):
"""Try to create a token while not an admin."""
channel = self.make_request(
"POST",
@@ -82,14 +74,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_create_using_defaults(self) -> None:
+ def test_create_using_defaults(self):
"""Create a token using all the defaults."""
channel = self.make_request(
"POST",
@@ -98,14 +86,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(len(channel.json_body["token"]), 16)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
self.assertEqual(channel.json_body["pending"], 0)
self.assertEqual(channel.json_body["completed"], 0)
- def test_create_specifying_fields(self) -> None:
+ def test_create_specifying_fields(self):
"""Create a token specifying the value of all fields."""
# As many of the allowed characters as possible with length <= 64
token = "adefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._~-"
@@ -122,14 +110,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["token"], token)
self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
self.assertEqual(channel.json_body["pending"], 0)
self.assertEqual(channel.json_body["completed"], 0)
- def test_create_with_null_value(self) -> None:
+ def test_create_with_null_value(self):
"""Create a token specifying unlimited uses and no expiry."""
data = {
"uses_allowed": None,
@@ -143,14 +131,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(len(channel.json_body["token"]), 16)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
self.assertEqual(channel.json_body["pending"], 0)
self.assertEqual(channel.json_body["completed"], 0)
- def test_create_token_too_long(self) -> None:
+ def test_create_token_too_long(self):
"""Check token longer than 64 chars is invalid."""
data = {"token": "a" * 65}
@@ -161,14 +149,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
- def test_create_token_invalid_chars(self) -> None:
+ def test_create_token_invalid_chars(self):
"""Check you can't create token with invalid characters."""
data = {
"token": "abc/def",
@@ -181,14 +165,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
- def test_create_token_already_exists(self) -> None:
+ def test_create_token_already_exists(self):
"""Check you can't create token that already exists."""
data = {
"token": "abcd",
@@ -200,7 +180,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
data,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel1.code, msg=channel1.json_body)
+ self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"])
channel2 = self.make_request(
"POST",
@@ -208,10 +188,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
data,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel2.code, msg=channel2.json_body)
+ self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"])
self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM)
- def test_create_unable_to_generate_token(self) -> None:
+ def test_create_unable_to_generate_token(self):
"""Check right error is raised when server can't generate unique token."""
# Create all possible single character tokens
tokens = []
@@ -240,9 +220,9 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 1},
access_token=self.admin_user_tok,
)
- self.assertEqual(500, channel.code, msg=channel.json_body)
+ self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"])
- def test_create_uses_allowed(self) -> None:
+ def test_create_uses_allowed(self):
"""Check you can only create a token with good values for uses_allowed."""
# Should work with 0 (token is invalid from the start)
channel = self.make_request(
@@ -251,7 +231,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 0},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["uses_allowed"], 0)
# Should fail with negative integer
@@ -261,11 +241,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": -5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with float
@@ -275,14 +251,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 1.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
- def test_create_expiry_time(self) -> None:
+ def test_create_expiry_time(self):
"""Check you can't create a token with an invalid expiry_time."""
# Should fail with a time in the past
channel = self.make_request(
@@ -291,11 +263,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": self.clock.time_msec() - 10000},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with float
@@ -305,14 +273,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": self.clock.time_msec() + 1000000.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
- def test_create_length(self) -> None:
+ def test_create_length(self):
"""Check you can only generate a token with a valid length."""
# Should work with 64
channel = self.make_request(
@@ -321,7 +285,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 64},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(len(channel.json_body["token"]), 64)
# Should fail with 0
@@ -331,11 +295,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 0},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a negative integer
@@ -345,11 +305,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": -5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a float
@@ -359,11 +315,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 8.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with 65
@@ -373,30 +325,22 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 65},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# UPDATING
- def test_update_no_auth(self) -> None:
+ def test_update_no_auth(self):
"""Try to update a token without authentication."""
channel = self.make_request(
"PUT",
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
)
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_update_requester_not_admin(self) -> None:
+ def test_update_requester_not_admin(self):
"""Try to update a token while not an admin."""
channel = self.make_request(
"PUT",
@@ -404,14 +348,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_update_non_existent(self) -> None:
+ def test_update_non_existent(self):
"""Try to update a token that doesn't exist."""
channel = self.make_request(
"PUT",
@@ -420,14 +360,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.NOT_FOUND,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
- def test_update_uses_allowed(self) -> None:
+ def test_update_uses_allowed(self):
"""Test updating just uses_allowed."""
# Create new token using default values
token = self._new_token()
@@ -439,7 +375,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 1},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertIsNone(channel.json_body["expiry_time"])
@@ -450,7 +386,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 0},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["uses_allowed"], 0)
self.assertIsNone(channel.json_body["expiry_time"])
@@ -461,7 +397,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": None},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
@@ -472,11 +408,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 1.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a negative integer
@@ -486,14 +418,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": -5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
- def test_update_expiry_time(self) -> None:
+ def test_update_expiry_time(self):
"""Test updating just expiry_time."""
# Create new token using default values
token = self._new_token()
@@ -506,7 +434,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": new_expiry_time},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
self.assertIsNone(channel.json_body["uses_allowed"])
@@ -517,7 +445,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": None},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertIsNone(channel.json_body["expiry_time"])
self.assertIsNone(channel.json_body["uses_allowed"])
@@ -529,11 +457,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": past_time},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail a float
@@ -543,14 +467,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": new_expiry_time + 0.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
- def test_update_both(self) -> None:
+ def test_update_both(self):
"""Test updating both uses_allowed and expiry_time."""
# Create new token using default values
token = self._new_token()
@@ -568,11 +488,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
- def test_update_invalid_type(self) -> None:
+ def test_update_invalid_type(self):
"""Test using invalid types doesn't work."""
# Create new token using default values
token = self._new_token()
@@ -589,30 +509,22 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# DELETING
- def test_delete_no_auth(self) -> None:
+ def test_delete_no_auth(self):
"""Try to delete a token without authentication."""
channel = self.make_request(
"DELETE",
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
)
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_delete_requester_not_admin(self) -> None:
+ def test_delete_requester_not_admin(self):
"""Try to delete a token while not an admin."""
channel = self.make_request(
"DELETE",
@@ -620,14 +532,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_delete_non_existent(self) -> None:
+ def test_delete_non_existent(self):
"""Try to delete a token that doesn't exist."""
channel = self.make_request(
"DELETE",
@@ -636,14 +544,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.NOT_FOUND,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
- def test_delete(self) -> None:
+ def test_delete(self):
"""Test deleting a token."""
# Create new token using default values
token = self._new_token()
@@ -655,25 +559,21 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# GETTING ONE
- def test_get_no_auth(self) -> None:
+ def test_get_no_auth(self):
"""Try to get a token without authentication."""
channel = self.make_request(
"GET",
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
)
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_get_requester_not_admin(self) -> None:
+ def test_get_requester_not_admin(self):
"""Try to get a token while not an admin."""
channel = self.make_request(
"GET",
@@ -681,14 +581,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_get_non_existent(self) -> None:
+ def test_get_non_existent(self):
"""Try to get a token that doesn't exist."""
channel = self.make_request(
"GET",
@@ -697,14 +593,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.NOT_FOUND,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
- def test_get(self) -> None:
+ def test_get(self):
"""Test getting a token."""
# Create new token using default values
token = self._new_token()
@@ -716,7 +608,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["token"], token)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
@@ -725,17 +617,13 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
# LISTING
- def test_list_no_auth(self) -> None:
+ def test_list_no_auth(self):
"""Try to list tokens without authentication."""
channel = self.make_request("GET", self.url, {})
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_list_requester_not_admin(self) -> None:
+ def test_list_requester_not_admin(self):
"""Try to list tokens while not an admin."""
channel = self.make_request(
"GET",
@@ -743,14 +631,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_list_all(self) -> None:
+ def test_list_all(self):
"""Test listing all tokens."""
# Create new token using default values
token = self._new_token()
@@ -762,7 +646,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(len(channel.json_body["registration_tokens"]), 1)
token_info = channel.json_body["registration_tokens"][0]
self.assertEqual(token_info["token"], token)
@@ -771,7 +655,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.assertEqual(token_info["pending"], 0)
self.assertEqual(token_info["completed"], 0)
- def test_list_invalid_query_parameter(self) -> None:
+ def test_list_invalid_query_parameter(self):
"""Test with `valid` query parameter not `true` or `false`."""
channel = self.make_request(
"GET",
@@ -780,13 +664,9 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- def _test_list_query_parameter(self, valid: str) -> None:
+ def _test_list_query_parameter(self, valid: str):
"""Helper used to test both valid=true and valid=false."""
# Create 2 valid and 2 invalid tokens.
now = self.hs.get_clock().time_msec()
@@ -816,17 +696,17 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(len(channel.json_body["registration_tokens"]), 2)
token_info_1 = channel.json_body["registration_tokens"][0]
token_info_2 = channel.json_body["registration_tokens"][1]
self.assertIn(token_info_1["token"], tokens)
self.assertIn(token_info_2["token"], tokens)
- def test_list_valid(self) -> None:
+ def test_list_valid(self):
"""Test listing just valid tokens."""
self._test_list_query_parameter(valid="true")
- def test_list_invalid(self) -> None:
+ def test_list_invalid(self):
"""Test listing just invalid tokens."""
self._test_list_query_parameter(valid="false")
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 22f9aa6234..07077aff78 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+import json
import urllib.parse
from http import HTTPStatus
from typing import List, Optional
@@ -18,15 +20,11 @@ from unittest.mock import Mock
from parameterized import parameterized
-from twisted.test.proto_helpers import MemoryReactor
-
import synapse.rest.admin
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import Codes
from synapse.handlers.pagination import PaginationHandler
from synapse.rest.client import directory, events, login, room
-from synapse.server import HomeServer
-from synapse.util import Clock
from tests import unittest
@@ -42,7 +40,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
room.register_deprecated_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ def prepare(self, reactor, clock, hs):
self.event_creation_handler = hs.get_event_creation_handler()
hs.config.consent.user_consent_version = "1"
@@ -68,7 +66,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
def test_requester_is_no_admin(self):
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -78,12 +76,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_room_does_not_exist(self):
"""
- Check that unknown rooms/server return 200
+ Check that unknown rooms/server return error 404.
"""
url = "/_synapse/admin/v1/rooms/%s" % "!unknown:test"
@@ -94,11 +92,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_room_is_not_valid(self):
"""
- Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
+ Check that invalid room names, return an error 400.
"""
url = "/_synapse/admin/v1/rooms/%s" % "invalidroom"
@@ -109,7 +108,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom is not a legal room ID",
channel.json_body["error"],
@@ -119,15 +118,16 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
Tests that the user ID must be from local server but it does not have to exist.
"""
+ body = json.dumps({"new_room_user_id": "@unknown:test"})
channel = self.make_request(
"DELETE",
self.url,
- content={"new_room_user_id": "@unknown:test"},
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("new_room_id", channel.json_body)
self.assertIn("kicked_users", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -137,15 +137,16 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
Check that only local users can create new room to move members.
"""
+ body = json.dumps({"new_room_user_id": "@not:exist.bla"})
channel = self.make_request(
"DELETE",
self.url,
- content={"new_room_user_id": "@not:exist.bla"},
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"User must be our own: @not:exist.bla",
channel.json_body["error"],
@@ -155,30 +156,32 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"""
If parameter `block` is not boolean, return an error
"""
+ body = json.dumps({"block": "NotBool"})
channel = self.make_request(
"DELETE",
self.url,
- content={"block": "NotBool"},
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_is_not_bool(self):
"""
If parameter `purge` is not boolean, return an error
"""
+ body = json.dumps({"purge": "NotBool"})
channel = self.make_request(
"DELETE",
self.url,
- content={"purge": "NotBool"},
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_room_and_block(self):
@@ -195,14 +198,16 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
# Assert one user in room
self._is_member(room_id=self.room_id, user_id=self.other_user)
+ body = json.dumps({"block": True, "purge": True})
+
channel = self.make_request(
"DELETE",
self.url.encode("ascii"),
- content={"block": True, "purge": True},
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -226,14 +231,16 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
# Assert one user in room
self._is_member(room_id=self.room_id, user_id=self.other_user)
+ body = json.dumps({"block": False, "purge": True})
+
channel = self.make_request(
"DELETE",
self.url.encode("ascii"),
- content={"block": False, "purge": True},
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -258,14 +265,16 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
# Assert one user in room
self._is_member(room_id=self.room_id, user_id=self.other_user)
+ body = json.dumps({"block": True, "purge": False})
+
channel = self.make_request(
"DELETE",
self.url.encode("ascii"),
- content={"block": True, "purge": False},
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -296,7 +305,9 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
)
# The room is now blocked.
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ HTTPStatus.OK, int(channel.result["code"]), msg=channel.result["body"]
+ )
self._is_blocked(room_id)
def test_shutdown_room_consent(self):
@@ -316,10 +327,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
# Assert that the user is getting consent error
self.helper.send(
- self.room_id,
- body="foo",
- tok=self.other_user_tok,
- expect_code=HTTPStatus.FORBIDDEN,
+ self.room_id, body="foo", tok=self.other_user_tok, expect_code=403
)
# Test that room is not purged
@@ -333,11 +341,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url,
- {"new_room_user_id": self.admin_user},
+ json.dumps({"new_room_user_id": self.admin_user}),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("new_room_id", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -363,10 +371,10 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
url.encode("ascii"),
- {"history_visibility": "world_readable"},
+ json.dumps({"history_visibility": "world_readable"}),
access_token=self.other_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Test that room is not purged
with self.assertRaises(AssertionError):
@@ -379,11 +387,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url,
- {"new_room_user_id": self.admin_user},
+ json.dumps({"new_room_user_id": self.admin_user}),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("new_room_id", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -398,7 +406,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self._has_no_members(self.room_id)
# Assert we can no longer peek into the room
- self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN)
+ self._assert_peek(self.room_id, expect_code=403)
def _is_blocked(self, room_id, expect=True):
"""Assert that the room is blocked or not"""
@@ -457,7 +465,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
room.register_deprecated_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ def prepare(self, reactor, clock, hs):
self.event_creation_handler = hs.get_event_creation_handler()
hs.config.consent.user_consent_version = "1"
@@ -494,7 +502,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
)
def test_requester_is_no_admin(self, method: str, url: str):
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -507,36 +515,27 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_room_does_not_exist(self):
+ @parameterized.expand(
+ [
+ ("DELETE", "/_synapse/admin/v2/rooms/%s"),
+ ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
+ ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"),
+ ]
+ )
+ def test_room_does_not_exist(self, method: str, url: str):
"""
- Check that unknown rooms/server return 200
-
- This is important, as it allows incomplete vestiges of rooms to be cleared up
- even if the create event/etc is missing.
+ Check that unknown rooms/server return error 404.
"""
- room_id = "!unknown:test"
- channel = self.make_request(
- "DELETE",
- f"/_synapse/admin/v2/rooms/{room_id}",
- content={},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
- self.assertIn("delete_id", channel.json_body)
- delete_id = channel.json_body["delete_id"]
-
- # get status
channel = self.make_request(
- "GET",
- f"/_synapse/admin/v2/rooms/{room_id}/delete_status",
+ method,
+ url % "!unknown:test",
+ content={},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
- self.assertEqual(1, len(channel.json_body["results"]))
- self.assertEqual("complete", channel.json_body["results"][0]["status"])
- self.assertEqual(delete_id, channel.json_body["results"][0]["delete_id"])
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@parameterized.expand(
[
@@ -546,7 +545,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
)
def test_room_is_not_valid(self, method: str, url: str):
"""
- Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
+ Check that invalid room names, return an error 400.
"""
channel = self.make_request(
@@ -855,10 +854,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
# Assert that the user is getting consent error
self.helper.send(
- self.room_id,
- body="foo",
- tok=self.other_user_tok,
- expect_code=HTTPStatus.FORBIDDEN,
+ self.room_id, body="foo", tok=self.other_user_tok, expect_code=403
)
# Test that room is not purged
@@ -955,7 +951,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self._has_no_members(self.room_id)
# Assert we can no longer peek into the room
- self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN)
+ self._assert_peek(self.room_id, expect_code=403)
def _is_blocked(self, room_id: str, expect: bool = True) -> None:
"""Assert that the room is blocked or not"""
@@ -1073,12 +1069,12 @@ class RoomTestCase(unittest.HomeserverTestCase):
directory.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ def prepare(self, reactor, clock, hs):
# Create user
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- def test_list_rooms(self) -> None:
+ def test_list_rooms(self):
"""Test that we can list rooms"""
# Create 3 test rooms
total_rooms = 3
@@ -1098,7 +1094,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
# Check request completed successfully
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that response json body contains a "rooms" key
self.assertTrue(
@@ -1142,7 +1138,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
# We shouldn't receive a next token here as there's no further rooms to show
self.assertNotIn("next_batch", channel.json_body)
- def test_list_rooms_pagination(self) -> None:
+ def test_list_rooms_pagination(self):
"""Test that we can get a full list of rooms through pagination"""
# Create 5 test rooms
total_rooms = 5
@@ -1182,7 +1178,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue("rooms" in channel.json_body)
for r in channel.json_body["rooms"]:
@@ -1222,9 +1218,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
- def test_correct_room_attributes(self) -> None:
+ def test_correct_room_attributes(self):
"""Test the correct attributes for a room are returned"""
# Create a test room
room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
@@ -1245,7 +1241,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
{"room_id": room_id},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Set this new alias as the canonical alias for this room
self.helper.send_state(
@@ -1277,7 +1273,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that rooms were returned
self.assertTrue("rooms" in channel.json_body)
@@ -1305,7 +1301,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(test_room_name, r["name"])
self.assertEqual(test_alias, r["canonical_alias"])
- def test_room_list_sort_order(self) -> None:
+ def test_room_list_sort_order(self):
"""Test room list sort ordering. alphabetical name versus number of members,
reversing the order, etc.
"""
@@ -1314,7 +1310,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
order_type: str,
expected_room_list: List[str],
reverse: bool = False,
- ) -> None:
+ ):
"""Request the list of rooms in a certain order. Assert that order is what
we expect
@@ -1332,7 +1328,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that rooms were returned
self.assertTrue("rooms" in channel.json_body)
@@ -1443,7 +1439,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
_order_test("state_events", [room_id_3, room_id_2, room_id_1])
_order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)
- def test_search_term(self) -> None:
+ def test_search_term(self):
"""Test that searching for a room works correctly"""
# Create two test rooms
room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
@@ -1471,8 +1467,8 @@ class RoomTestCase(unittest.HomeserverTestCase):
def _search_test(
expected_room_id: Optional[str],
search_term: str,
- expected_http_code: int = HTTPStatus.OK,
- ) -> None:
+ expected_http_code: int = 200,
+ ):
"""Search for a room and check that the returned room's id is a match
Args:
@@ -1489,7 +1485,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
- if expected_http_code != HTTPStatus.OK:
+ if expected_http_code != 200:
return
# Check that rooms were returned
@@ -1532,7 +1528,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
_search_test(None, "foo")
_search_test(None, "bar")
- _search_test(None, "", expected_http_code=HTTPStatus.BAD_REQUEST)
+ _search_test(None, "", expected_http_code=400)
# Test that the whole room id returns the room
_search_test(room_id_1, room_id_1)
@@ -1546,7 +1542,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Test search local part of alias
_search_test(room_id_1, "alias1")
- def test_search_term_non_ascii(self) -> None:
+ def test_search_term_non_ascii(self):
"""Test that searching for a room with non-ASCII characters works correctly"""
# Create test room
@@ -1569,11 +1565,11 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(room_id, channel.json_body.get("rooms")[0].get("room_id"))
self.assertEqual("ж", channel.json_body.get("rooms")[0].get("name"))
- def test_single_room(self) -> None:
+ def test_single_room(self):
"""Test that a single room can be requested correctly"""
# Create two test rooms
room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
@@ -1602,7 +1598,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("room_id", channel.json_body)
self.assertIn("name", channel.json_body)
@@ -1624,7 +1620,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(room_id_1, channel.json_body["room_id"])
- def test_single_room_devices(self) -> None:
+ def test_single_room_devices(self):
"""Test that `joined_local_devices` can be requested correctly"""
room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
@@ -1634,7 +1630,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["joined_local_devices"])
# Have another user join the room
@@ -1648,7 +1644,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(2, channel.json_body["joined_local_devices"])
# leave room
@@ -1660,10 +1656,10 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["joined_local_devices"])
- def test_room_members(self) -> None:
+ def test_room_members(self):
"""Test that room members can be requested correctly"""
# Create two test rooms
room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
@@ -1691,7 +1687,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertCountEqual(
["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"]
@@ -1704,14 +1700,14 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertCountEqual(
["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"]
)
self.assertEqual(channel.json_body["total"], 3)
- def test_room_state(self) -> None:
+ def test_room_state(self):
"""Test that room state can be requested correctly"""
# Create two test rooms
room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
@@ -1722,15 +1718,13 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("state", channel.json_body)
# testing that the state events match is painful and not done here. We assume that
# the create_room already does the right thing, so no need to verify that we got
# the state events it created.
- def _set_canonical_alias(
- self, room_id: str, test_alias: str, admin_user_tok: str
- ) -> None:
+ def _set_canonical_alias(self, room_id: str, test_alias: str, admin_user_tok: str):
# Create a new alias to this room
url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
channel = self.make_request(
@@ -1739,7 +1733,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
{"room_id": room_id},
access_token=admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Set this new alias as the canonical alias for this room
self.helper.send_state(
@@ -1765,7 +1759,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ def prepare(self, reactor, clock, homeserver):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -1780,117 +1774,124 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
)
self.url = f"/_synapse/admin/v1/join/{self.public_room_id}"
- def test_requester_is_no_admin(self) -> None:
+ def test_requester_is_no_admin(self):
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
+ body = json.dumps({"user_id": self.second_user_id})
channel = self.make_request(
"POST",
self.url,
- content={"user_id": self.second_user_id},
+ content=body,
access_token=self.second_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_invalid_parameter(self) -> None:
+ def test_invalid_parameter(self):
"""
If a parameter is missing, return an error
"""
+ body = json.dumps({"unknown_parameter": "@unknown:test"})
channel = self.make_request(
"POST",
self.url,
- content={"unknown_parameter": "@unknown:test"},
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
- def test_local_user_does_not_exist(self) -> None:
+ def test_local_user_does_not_exist(self):
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
+ body = json.dumps({"user_id": "@unknown:test"})
channel = self.make_request(
"POST",
self.url,
- content={"user_id": "@unknown:test"},
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- def test_remote_user(self) -> None:
+ def test_remote_user(self):
"""
Check that only local user can join rooms.
"""
+ body = json.dumps({"user_id": "@not:exist.bla"})
channel = self.make_request(
"POST",
self.url,
- content={"user_id": "@not:exist.bla"},
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"This endpoint can only be used with local users",
channel.json_body["error"],
)
- def test_room_does_not_exist(self) -> None:
+ def test_room_does_not_exist(self):
"""
- Check that unknown rooms/server return error HTTPStatus.NOT_FOUND.
+ Check that unknown rooms/server return error 404.
"""
+ body = json.dumps({"user_id": self.second_user_id})
url = "/_synapse/admin/v1/join/!unknown:test"
channel = self.make_request(
"POST",
url,
- content={"user_id": self.second_user_id},
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual("No known servers", channel.json_body["error"])
- def test_room_is_not_valid(self) -> None:
+ def test_room_is_not_valid(self):
"""
- Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
+ Check that invalid room names, return an error 400.
"""
+ body = json.dumps({"user_id": self.second_user_id})
url = "/_synapse/admin/v1/join/invalidroom"
channel = self.make_request(
"POST",
url,
- content={"user_id": self.second_user_id},
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom was not legal room ID or room alias",
channel.json_body["error"],
)
- def test_join_public_room(self) -> None:
+ def test_join_public_room(self):
"""
Test joining a local user to a public room with "JoinRules.PUBLIC"
"""
+ body = json.dumps({"user_id": self.second_user_id})
channel = self.make_request(
"POST",
self.url,
- content={"user_id": self.second_user_id},
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.public_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -1900,10 +1901,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEquals(200, channel.code, msg=channel.json_body)
self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
- def test_join_private_room_if_not_member(self) -> None:
+ def test_join_private_room_if_not_member(self):
"""
Test joining a local user to a private room with "JoinRules.INVITE"
when server admin is not member of this room.
@@ -1912,18 +1913,19 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
self.creator, tok=self.creator_tok, is_public=False
)
url = f"/_synapse/admin/v1/join/{private_room_id}"
+ body = json.dumps({"user_id": self.second_user_id})
channel = self.make_request(
"POST",
url,
- content={"user_id": self.second_user_id},
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_join_private_room_if_member(self) -> None:
+ def test_join_private_room_if_member(self):
"""
Test joining a local user to a private room with "JoinRules.INVITE",
when server admin is member of this room.
@@ -1948,20 +1950,21 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.admin_user_tok,
)
- self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEquals(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
# Join user to room.
url = f"/_synapse/admin/v1/join/{private_room_id}"
+ body = json.dumps({"user_id": self.second_user_id})
channel = self.make_request(
"POST",
url,
- content={"user_id": self.second_user_id},
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -1971,10 +1974,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEquals(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
- def test_join_private_room_if_owner(self) -> None:
+ def test_join_private_room_if_owner(self):
"""
Test joining a local user to a private room with "JoinRules.INVITE",
when server admin is owner of this room.
@@ -1983,15 +1986,16 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
self.admin_user, tok=self.admin_user_tok, is_public=False
)
url = f"/_synapse/admin/v1/join/{private_room_id}"
+ body = json.dumps({"user_id": self.second_user_id})
channel = self.make_request(
"POST",
url,
- content={"user_id": self.second_user_id},
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -2001,10 +2005,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEquals(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
- def test_context_as_non_admin(self) -> None:
+ def test_context_as_non_admin(self):
"""
Test that, without being admin, one cannot use the context admin API
"""
@@ -2035,10 +2039,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
% (room_id, events[midway]["event_id"]),
access_token=tok,
)
- self.assertEquals(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEquals(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_context_as_admin(self) -> None:
+ def test_context_as_admin(self):
"""
Test that, as admin, we can find the context of an event without having joined the room.
"""
@@ -2065,7 +2069,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
% (room_id, events[midway]["event_id"]),
access_token=self.admin_user_tok,
)
- self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEquals(200, channel.code, msg=channel.json_body)
self.assertEquals(
channel.json_body["event"]["event_id"], events[midway]["event_id"]
)
@@ -2094,7 +2098,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ def prepare(self, reactor, clock, homeserver):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -2111,7 +2115,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
self.public_room_id
)
- def test_public_room(self) -> None:
+ def test_public_room(self):
"""Test that getting admin in a public room works."""
room_id = self.helper.create_room_as(
self.creator, tok=self.creator_tok, is_public=True
@@ -2124,7 +2128,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room and ban a user.
self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
@@ -2136,7 +2140,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
tok=self.admin_user_tok,
)
- def test_private_room(self) -> None:
+ def test_private_room(self):
"""Test that getting admin in a private room works and we get invited."""
room_id = self.helper.create_room_as(
self.creator,
@@ -2151,7 +2155,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room (we should have received an
# invite) and can ban a user.
@@ -2164,7 +2168,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
tok=self.admin_user_tok,
)
- def test_other_user(self) -> None:
+ def test_other_user(self):
"""Test that giving admin in a public room works to a non-admin user works."""
room_id = self.helper.create_room_as(
self.creator, tok=self.creator_tok, is_public=True
@@ -2177,7 +2181,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room and ban a user.
self.helper.join(room_id, self.second_user_id, tok=self.second_tok)
@@ -2189,7 +2193,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
tok=self.second_tok,
)
- def test_not_enough_power(self) -> None:
+ def test_not_enough_power(self):
"""Test that we get a sensible error if there are no local room admins."""
room_id = self.helper.create_room_as(
self.creator, tok=self.creator_tok, is_public=True
@@ -2211,11 +2215,11 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- # We expect this to fail with a HTTPStatus.BAD_REQUEST as there are no room admins.
+ # We expect this to fail with a 400 as there are no room admins.
#
# (Note we assert the error message to ensure that it's not denied for
# some other reason)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body["error"],
"No local admin user in room with power to update power levels.",
@@ -2229,7 +2233,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ def prepare(self, reactor, clock, hs):
self._store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -2244,8 +2248,8 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
self.url = "/_synapse/admin/v1/rooms/%s/block"
@parameterized.expand([("PUT",), ("GET",)])
- def test_requester_is_no_admin(self, method: str) -> None:
- """If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned."""
+ def test_requester_is_no_admin(self, method: str):
+ """If the user is not a server admin, an error 403 is returned."""
channel = self.make_request(
method,
@@ -2258,8 +2262,8 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand([("PUT",), ("GET",)])
- def test_room_is_not_valid(self, method: str) -> None:
- """Check that invalid room names, return an error HTTPStatus.BAD_REQUEST."""
+ def test_room_is_not_valid(self, method: str):
+ """Check that invalid room names, return an error 400."""
channel = self.make_request(
method,
@@ -2274,7 +2278,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
channel.json_body["error"],
)
- def test_block_is_not_valid(self) -> None:
+ def test_block_is_not_valid(self):
"""If parameter `block` is not valid, return an error."""
# `block` is not valid
@@ -2309,7 +2313,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
- def test_block_room(self) -> None:
+ def test_block_room(self):
"""Test that block a room is successful."""
def _request_and_test_block_room(room_id: str) -> None:
@@ -2333,7 +2337,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
# unknown remote room
_request_and_test_block_room("!unknown:remote")
- def test_block_room_twice(self) -> None:
+ def test_block_room_twice(self):
"""Test that block a room that is already blocked is successful."""
self._is_blocked(self.room_id, expect=False)
@@ -2348,7 +2352,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
self.assertTrue(channel.json_body["block"])
self._is_blocked(self.room_id, expect=True)
- def test_unblock_room(self) -> None:
+ def test_unblock_room(self):
"""Test that unblock a room is successful."""
def _request_and_test_unblock_room(room_id: str) -> None:
@@ -2373,7 +2377,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
# unknown remote room
_request_and_test_unblock_room("!unknown:remote")
- def test_unblock_room_twice(self) -> None:
+ def test_unblock_room_twice(self):
"""Test that unblock a room that is not blocked is successful."""
self._block_room(self.room_id)
@@ -2388,7 +2392,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
self.assertFalse(channel.json_body["block"])
self._is_blocked(self.room_id, expect=False)
- def test_get_blocked_room(self) -> None:
+ def test_get_blocked_room(self):
"""Test get status of a blocked room"""
def _request_blocked_room(room_id: str) -> None:
@@ -2412,7 +2416,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
# unknown remote room
_request_blocked_room("!unknown:remote")
- def test_get_unblocked_room(self) -> None:
+ def test_get_unblocked_room(self):
"""Test get status of a unblocked room"""
def _request_unblocked_room(room_id: str) -> None:
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index 3c59f5f766..fbceba3254 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -11,18 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
-from typing import List
-from twisted.test.proto_helpers import MemoryReactor
+from typing import List
import synapse.rest.admin
from synapse.api.errors import Codes
from synapse.rest.client import login, room, sync
-from synapse.server import HomeServer
from synapse.storage.roommember import RoomsForUser
from synapse.types import JsonDict
-from synapse.util import Clock
from tests import unittest
from tests.unittest import override_config
@@ -37,7 +33,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
sync.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.room_shutdown_handler = hs.get_room_shutdown_handler()
self.pagination_handler = hs.get_pagination_handler()
@@ -52,18 +48,14 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
self.url = "/_synapse/admin/v1/send_server_notice"
- def test_no_auth(self) -> None:
+ def test_no_auth(self):
"""Try to send a server notice without authentication."""
channel = self.make_request("POST", self.url)
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_no_admin(self) -> None:
+ def test_requester_is_no_admin(self):
"""If the user is not a server admin, an error is returned."""
channel = self.make_request(
"POST",
@@ -71,16 +63,12 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@override_config({"server_notices": {"system_mxid_localpart": "notices"}})
- def test_user_does_not_exist(self) -> None:
- """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND"""
+ def test_user_does_not_exist(self):
+ """Tests that a lookup for a user that does not exist returns a 404"""
channel = self.make_request(
"POST",
self.url,
@@ -88,13 +76,13 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": "@unknown_person:test", "content": ""},
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@override_config({"server_notices": {"system_mxid_localpart": "notices"}})
- def test_user_is_not_local(self) -> None:
+ def test_user_is_not_local(self):
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
channel = self.make_request(
"POST",
@@ -106,13 +94,13 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"Server notices can only be sent to local users", channel.json_body["error"]
)
@override_config({"server_notices": {"system_mxid_localpart": "notices"}})
- def test_invalid_parameter(self) -> None:
+ def test_invalid_parameter(self):
"""If parameters are invalid, an error is returned."""
# no content, no user
@@ -122,7 +110,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
# no content
@@ -133,7 +121,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": self.other_user},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
# no body
@@ -144,7 +132,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": self.other_user, "content": ""},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual("'body' not in content", channel.json_body["error"])
@@ -156,11 +144,11 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": self.other_user, "content": {"body": ""}},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual("'msgtype' not in content", channel.json_body["error"])
- def test_server_notice_disabled(self) -> None:
+ def test_server_notice_disabled(self):
"""Tests that server returns error if server notice is disabled"""
channel = self.make_request(
"POST",
@@ -172,14 +160,14 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual(
"Server notices are not enabled on this server", channel.json_body["error"]
)
@override_config({"server_notices": {"system_mxid_localpart": "notices"}})
- def test_send_server_notice(self) -> None:
+ def test_send_server_notice(self):
"""
Tests that sending two server notices is successfully,
the server uses the same room and do not send messages twice.
@@ -197,7 +185,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg one"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -228,7 +216,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg two"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has no new invites or memberships
self._check_invite_and_join_status(self.other_user, 0, 1)
@@ -243,7 +231,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
self.assertEqual(messages[1]["sender"], "@notices:test")
@override_config({"server_notices": {"system_mxid_localpart": "notices"}})
- def test_send_server_notice_leave_room(self) -> None:
+ def test_send_server_notice_leave_room(self):
"""
Tests that sending a server notices is successfully.
The user leaves the room and the second message appears
@@ -262,7 +250,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg one"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -305,7 +293,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg two"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -327,7 +315,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(first_room_id, second_room_id)
@override_config({"server_notices": {"system_mxid_localpart": "notices"}})
- def test_send_server_notice_delete_room(self) -> None:
+ def test_send_server_notice_delete_room(self):
"""
Tests that the user get server notice in a new room
after the first server notice room was deleted.
@@ -345,7 +333,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg one"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -394,7 +382,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg two"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -417,7 +405,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
def _check_invite_and_join_status(
self, user_id: str, expected_invites: int, expected_memberships: int
- ) -> List[RoomsForUser]:
+ ) -> RoomsForUser:
"""Check invite and room membership status of a user.
Args
@@ -452,7 +440,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", "/_matrix/client/r0/sync", access_token=token
)
- self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(channel.code, 200)
# Get the messages
room = channel.json_body["rooms"]["join"][room_id]
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 7cb8ec57ba..ece89a65ac 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -12,17 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
-from typing import List, Optional
-from twisted.test.proto_helpers import MemoryReactor
+import json
+from typing import Any, Dict, List, Optional
import synapse.rest.admin
from synapse.api.errors import Codes
from synapse.rest.client import login
-from synapse.server import HomeServer
-from synapse.types import JsonDict
-from synapse.util import Clock
from tests import unittest
from tests.test_utils import SMALL_PNG
@@ -34,7 +30,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ def prepare(self, reactor, clock, hs):
self.media_repo = hs.get_media_repository_resource()
self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -45,38 +41,30 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url = "/_synapse/admin/v1/statistics/users/media"
- def test_no_auth(self) -> None:
+ def test_no_auth(self):
"""
Try to list users without authentication.
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_no_admin(self) -> None:
+ def test_requester_is_no_admin(self):
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
"GET",
self.url,
- {},
+ json.dumps({}),
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_invalid_parameter(self) -> None:
+ def test_invalid_parameter(self):
"""
If parameters are invalid, an error is returned.
"""
@@ -87,11 +75,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -101,11 +85,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative limit
@@ -115,11 +95,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from_ts
@@ -129,11 +105,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative until_ts
@@ -143,11 +115,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# until_ts smaller from_ts
@@ -157,11 +125,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# empty search term
@@ -171,11 +135,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -185,14 +145,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
- def test_limit(self) -> None:
+ def test_limit(self):
"""
Testing list of media with limit
"""
@@ -204,13 +160,13 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["users"]), 5)
self.assertEqual(channel.json_body["next_token"], 5)
self._check_fields(channel.json_body["users"])
- def test_from(self) -> None:
+ def test_from(self):
"""
Testing list of media with a defined starting point (from)
"""
@@ -222,13 +178,13 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["users"]), 15)
self.assertNotIn("next_token", channel.json_body)
self._check_fields(channel.json_body["users"])
- def test_limit_and_from(self) -> None:
+ def test_limit_and_from(self):
"""
Testing list of media with a defined starting point and limit
"""
@@ -240,13 +196,13 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(channel.json_body["next_token"], 15)
self.assertEqual(len(channel.json_body["users"]), 10)
self._check_fields(channel.json_body["users"])
- def test_next_token(self) -> None:
+ def test_next_token(self):
"""
Testing that `next_token` appears at the right place
"""
@@ -262,7 +218,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
@@ -275,7 +231,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
@@ -288,7 +244,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 19)
self.assertEqual(channel.json_body["next_token"], 19)
@@ -301,12 +257,12 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 1)
self.assertNotIn("next_token", channel.json_body)
- def test_no_media(self) -> None:
+ def test_no_media(self):
"""
Tests that a normal lookup for statistics is successfully
if users have no media created
@@ -318,11 +274,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["users"]))
- def test_order_by(self) -> None:
+ def test_order_by(self):
"""
Testing order list with parameter `order_by`
"""
@@ -400,7 +356,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"b",
)
- def test_from_until_ts(self) -> None:
+ def test_from_until_ts(self):
"""
Testing filter by time with parameters `from_ts` and `until_ts`
"""
@@ -415,7 +371,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
# filter media starting at `ts1` after creating first media
@@ -425,7 +381,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?from_ts=%s" % (ts1,),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 0)
self._create_media(self.other_user_tok, 3)
@@ -440,7 +396,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
# filter media until `ts2` and earlier
@@ -449,10 +405,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?until_ts=%s" % (ts2,),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["users"][0]["media_count"], 6)
- def test_search_term(self) -> None:
+ def test_search_term(self):
self._create_users_with_media(20, 1)
# check without filter get all users
@@ -461,7 +417,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
# filter user 1 and 10-19 by `user_id`
@@ -470,7 +426,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?search_term=foo_user_1",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 11)
# filter on this user in `displayname`
@@ -479,7 +435,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?search_term=bar_user_10",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["users"][0]["displayname"], "bar_user_10")
self.assertEqual(channel.json_body["total"], 1)
@@ -489,10 +445,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?search_term=foobar",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 0)
- def _create_users_with_media(self, number_users: int, media_per_user: int) -> None:
+ def _create_users_with_media(self, number_users: int, media_per_user: int):
"""
Create a number of users with a number of media
Args:
@@ -504,7 +460,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
user_tok = self.login("foo_user_%s" % i, "pass")
self._create_media(user_tok, media_per_user)
- def _create_media(self, user_token: str, number_media: int) -> None:
+ def _create_media(self, user_token: str, number_media: int):
"""
Create a number of media for a specific user
Args:
@@ -515,10 +471,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
for _ in range(number_media):
# Upload some media into the room
self.helper.upload_media(
- upload_resource, SMALL_PNG, tok=user_token, expect_code=HTTPStatus.OK
+ upload_resource, SMALL_PNG, tok=user_token, expect_code=200
)
- def _check_fields(self, content: List[JsonDict]) -> None:
+ def _check_fields(self, content: List[Dict[str, Any]]):
"""Checks that all attributes are present in content
Args:
content: List that is checked for content
@@ -531,7 +487,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
def _order_test(
self, order_type: str, expected_user_list: List[str], dir: Optional[str] = None
- ) -> None:
+ ):
"""Request the list of users in a certain order. Assert that order is what
we expect
Args:
@@ -549,7 +505,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_user_list))
returned_order = [row["user_id"] for row in channel.json_body["users"]]
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 4fedd5fd08..5011e54563 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -17,7 +17,6 @@ import hmac
import os
import urllib.parse
from binascii import unhexlify
-from http import HTTPStatus
from typing import List, Optional
from unittest.mock import Mock, patch
@@ -75,7 +74,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"Shared secret registration is not enabled", channel.json_body["error"]
)
@@ -107,7 +106,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = {"nonce": nonce}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("username must be specified", channel.json_body["error"])
# 61 seconds
@@ -115,7 +114,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"])
def test_register_incorrect_nonce(self):
@@ -127,18 +126,18 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin")
- want_mac_str = want_mac.hexdigest()
+ want_mac = want_mac.hexdigest()
body = {
"nonce": nonce,
"username": "bob",
"password": "abc123",
"admin": True,
- "mac": want_mac_str,
+ "mac": want_mac,
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("HMAC incorrect", channel.json_body["error"])
def test_register_correct_nonce(self):
@@ -153,7 +152,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac.update(
nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support"
)
- want_mac_str = want_mac.hexdigest()
+ want_mac = want_mac.hexdigest()
body = {
"nonce": nonce,
@@ -161,11 +160,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"password": "abc123",
"admin": True,
"user_type": UserTypes.SUPPORT,
- "mac": want_mac_str,
+ "mac": want_mac,
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
def test_nonce_reuse(self):
@@ -177,24 +176,24 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin")
- want_mac_str = want_mac.hexdigest()
+ want_mac = want_mac.hexdigest()
body = {
"nonce": nonce,
"username": "bob",
"password": "abc123",
"admin": True,
- "mac": want_mac_str,
+ "mac": want_mac,
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
# Now, try and reuse it
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"])
def test_missing_parts(self):
@@ -215,7 +214,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be an empty body present
channel = self.make_request("POST", self.url, {})
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("nonce must be specified", channel.json_body["error"])
#
@@ -225,28 +224,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be present
channel = self.make_request("POST", self.url, {"nonce": nonce()})
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("username must be specified", channel.json_body["error"])
# Must be a string
body = {"nonce": nonce(), "username": 1234}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
body = {"nonce": nonce(), "username": "abcd\u0000"}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
body = {"nonce": nonce(), "username": "a" * 1000}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
#
@@ -257,28 +256,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = {"nonce": nonce(), "username": "a"}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("password must be specified", channel.json_body["error"])
# Must be a string
body = {"nonce": nonce(), "username": "a", "password": 1234}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
# Must not have null bytes
body = {"nonce": nonce(), "username": "a", "password": "abcd\u0000"}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
# Super long
body = {"nonce": nonce(), "username": "a", "password": "A" * 1000}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
#
@@ -294,7 +293,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid user type", channel.json_body["error"])
def test_displayname(self):
@@ -308,22 +307,22 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(nonce.encode("ascii") + b"\x00bob1\x00abc123\x00notadmin")
- want_mac_str = want_mac.hexdigest()
+ want_mac = want_mac.hexdigest()
body = {
"nonce": nonce,
"username": "bob1",
"password": "abc123",
- "mac": want_mac_str,
+ "mac": want_mac,
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob1:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob1:test/displayname")
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("bob1", channel.json_body["displayname"])
# displayname is None
@@ -332,22 +331,22 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(nonce.encode("ascii") + b"\x00bob2\x00abc123\x00notadmin")
- want_mac_str = want_mac.hexdigest()
+ want_mac = want_mac.hexdigest()
body = {
"nonce": nonce,
"username": "bob2",
"displayname": None,
"password": "abc123",
- "mac": want_mac_str,
+ "mac": want_mac,
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob2:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob2:test/displayname")
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("bob2", channel.json_body["displayname"])
# displayname is empty
@@ -356,22 +355,22 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(nonce.encode("ascii") + b"\x00bob3\x00abc123\x00notadmin")
- want_mac_str = want_mac.hexdigest()
+ want_mac = want_mac.hexdigest()
body = {
"nonce": nonce,
"username": "bob3",
"displayname": "",
"password": "abc123",
- "mac": want_mac_str,
+ "mac": want_mac,
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob3:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob3:test/displayname")
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
# set displayname
channel = self.make_request("GET", self.url)
@@ -379,22 +378,22 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(nonce.encode("ascii") + b"\x00bob4\x00abc123\x00notadmin")
- want_mac_str = want_mac.hexdigest()
+ want_mac = want_mac.hexdigest()
body = {
"nonce": nonce,
"username": "bob4",
"displayname": "Bob's Name",
"password": "abc123",
- "mac": want_mac_str,
+ "mac": want_mac,
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob4:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob4:test/displayname")
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("Bob's Name", channel.json_body["displayname"])
@override_config(
@@ -426,7 +425,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac.update(
nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support"
)
- want_mac_str = want_mac.hexdigest()
+ want_mac = want_mac.hexdigest()
body = {
"nonce": nonce,
@@ -434,11 +433,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"password": "abc123",
"admin": True,
"user_type": UserTypes.SUPPORT,
- "mac": want_mac_str,
+ "mac": want_mac,
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -462,7 +461,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self):
@@ -474,7 +473,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", self.url, access_token=other_user_token)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_all_users(self):
@@ -490,7 +489,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(3, len(channel.json_body["users"]))
self.assertEqual(3, channel.json_body["total"])
@@ -504,7 +503,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
expected_user_id: Optional[str],
search_term: str,
search_field: Optional[str] = "name",
- expected_http_code: Optional[int] = HTTPStatus.OK,
+ expected_http_code: Optional[int] = 200,
):
"""Search for a user and check that the returned user's id is a match
@@ -526,7 +525,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
- if expected_http_code != HTTPStatus.OK:
+ if expected_http_code != 200:
return
# Check that users were returned
@@ -587,7 +586,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -597,7 +596,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid guests
@@ -607,7 +606,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# invalid deactivated
@@ -617,7 +616,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# unkown order_by
@@ -627,7 +626,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# invalid search order
@@ -637,7 +636,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
def test_limit(self):
@@ -655,7 +654,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 5)
self.assertEqual(channel.json_body["next_token"], "5")
@@ -676,7 +675,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 15)
self.assertNotIn("next_token", channel.json_body)
@@ -697,7 +696,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(channel.json_body["next_token"], "15")
self.assertEqual(len(channel.json_body["users"]), 10)
@@ -720,7 +719,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
@@ -733,7 +732,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
@@ -746,7 +745,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 19)
self.assertEqual(channel.json_body["next_token"], "19")
@@ -760,7 +759,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -863,14 +862,14 @@ class UsersListTestCase(unittest.HomeserverTestCase):
url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_user_list))
returned_order = [row["name"] for row in channel.json_body["users"]]
self.assertEqual(expected_user_list, returned_order)
self._check_fields(channel.json_body["users"])
- def _check_fields(self, content: List[JsonDict]):
+ def _check_fields(self, content: JsonDict):
"""Checks that the expected user attributes are present in content
Args:
content: List that is checked for content
@@ -937,7 +936,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_not_admin(self):
@@ -948,7 +947,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", url, access_token=self.other_user_token)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
channel = self.make_request(
@@ -958,12 +957,12 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content=b"{}",
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
def test_user_does_not_exist(self):
"""
- Tests that deactivation for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that deactivation for a user that does not exist returns a 404
"""
channel = self.make_request(
@@ -972,7 +971,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_erase_is_not_bool(self):
@@ -987,18 +986,18 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_user_is_not_local(self):
"""
- Tests that deactivation for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that deactivation for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain"
channel = self.make_request("POST", url, access_token=self.admin_user_tok)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only deactivate local users", channel.json_body["error"])
def test_deactivate_user_erase_true(self):
@@ -1013,7 +1012,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@@ -1028,7 +1027,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content={"erase": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Get user
channel = self.make_request(
@@ -1037,7 +1036,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
@@ -1058,7 +1057,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@@ -1073,7 +1072,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content={"erase": False},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Get user
channel = self.make_request(
@@ -1082,7 +1081,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
@@ -1112,7 +1111,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@@ -1127,7 +1126,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content={"erase": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Get user
channel = self.make_request(
@@ -1136,7 +1135,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
@@ -1196,7 +1195,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
channel = self.make_request(
@@ -1206,12 +1205,12 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=b"{}",
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
def test_user_does_not_exist(self):
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
channel = self.make_request(
@@ -1220,7 +1219,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"])
def test_invalid_parameter(self):
@@ -1235,7 +1234,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"admin": "not_bool"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
# deactivated not bool
@@ -1245,7 +1244,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": "not_bool"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# password not str
@@ -1255,7 +1254,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"password": True},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# password not length
@@ -1265,7 +1264,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"password": "x" * 513},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# user_type not valid
@@ -1275,7 +1274,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"user_type": "new type"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# external_ids not valid
@@ -1287,7 +1286,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"external_ids": {"auth_provider": "prov", "wrong_external_id": "id"}
},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
channel = self.make_request(
@@ -1296,7 +1295,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"external_ids": {"external_id": "id"}},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
# threepids not valid
@@ -1306,7 +1305,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"threepids": {"medium": "email", "wrong_address": "id"}},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
channel = self.make_request(
@@ -1315,7 +1314,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"threepids": {"address": "value"}},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
def test_get_user(self):
@@ -1328,7 +1327,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("User", channel.json_body["displayname"])
self._check_fields(channel.json_body)
@@ -1371,7 +1370,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1434,7 +1433,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1462,9 +1461,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# before limit of monthly active users is reached
channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok)
- if channel.code != HTTPStatus.OK:
+ if channel.code != 200:
raise HttpResponseException(
- channel.code, channel.result["reason"], channel.json_body
+ channel.code, channel.result["reason"], channel.result["body"]
)
# Set monthly active users to the limit
@@ -1626,7 +1625,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "hahaha"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self._check_fields(channel.json_body)
def test_set_displayname(self):
@@ -1642,7 +1641,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"displayname": "foobar"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
@@ -1653,7 +1652,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
@@ -1675,7 +1674,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
# result does not always have the same sort order, therefore it becomes sorted
@@ -1701,7 +1700,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1717,7 +1716,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1733,7 +1732,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"threepids": []},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self._check_fields(channel.json_body)
@@ -1760,7 +1759,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1779,7 +1778,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1801,7 +1800,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# other user has this two threepids
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
# result does not always have the same sort order, therefore it becomes sorted
@@ -1820,7 +1819,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
url_first_user,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self._check_fields(channel.json_body)
@@ -1849,7 +1848,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["external_ids"]))
# result does not always have the same sort order, therefore it becomes sorted
@@ -1881,7 +1880,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -1900,7 +1899,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -1919,7 +1918,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"external_ids": []},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(0, len(channel.json_body["external_ids"]))
@@ -1948,7 +1947,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -1974,7 +1973,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -2006,7 +2005,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# must fail
- self.assertEqual(HTTPStatus.CONFLICT, channel.code, msg=channel.json_body)
+ self.assertEqual(409, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual("External id is already in use.", channel.json_body["error"])
@@ -2017,7 +2016,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -2035,7 +2034,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -2066,7 +2065,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@@ -2081,7 +2080,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"deactivated": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
@@ -2097,7 +2096,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
@@ -2124,7 +2123,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"deactivated": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
@@ -2140,7 +2139,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"displayname": "Foobar"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
self.assertEqual("Foobar", channel.json_body["displayname"])
@@ -2164,7 +2163,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
# Reactivate the user.
channel = self.make_request(
@@ -2173,7 +2172,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self.assertIsNotNone(channel.json_body["password_hash"])
@@ -2195,7 +2194,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Reactivate the user without a password.
@@ -2205,7 +2204,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
@@ -2227,7 +2226,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Reactivate the user without a password.
@@ -2237,7 +2236,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
@@ -2256,7 +2255,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"admin": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["admin"])
@@ -2267,7 +2266,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["admin"])
@@ -2284,7 +2283,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"user_type": UserTypes.SUPPORT},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
@@ -2295,7 +2294,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
@@ -2307,7 +2306,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"user_type": None},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertIsNone(channel.json_body["user_type"])
@@ -2318,7 +2317,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertIsNone(channel.json_body["user_type"])
@@ -2348,7 +2347,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])
self.assertEqual(0, channel.json_body["deactivated"])
@@ -2361,7 +2360,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123", "deactivated": "false"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
# Check user is not deactivated
channel = self.make_request(
@@ -2370,7 +2369,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])
@@ -2395,7 +2394,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
self._is_erased(user_id, False)
@@ -2446,7 +2445,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self):
@@ -2461,7 +2460,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self):
@@ -2475,7 +2474,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["joined_rooms"]))
@@ -2491,7 +2490,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["joined_rooms"]))
@@ -2507,7 +2506,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["joined_rooms"]))
@@ -2528,7 +2527,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_rooms, channel.json_body["total"])
self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
@@ -2575,7 +2574,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"])
@@ -2604,7 +2603,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self):
@@ -2619,12 +2618,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self):
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/pushers"
channel = self.make_request(
@@ -2633,12 +2632,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self):
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers"
@@ -2648,7 +2647,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only look up local users", channel.json_body["error"])
def test_get_pushers(self):
@@ -2663,7 +2662,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
# Register the pusher
@@ -2694,7 +2693,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
for p in channel.json_body["pushers"]:
@@ -2733,7 +2732,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
"""Try to list media of an user without authentication."""
channel = self.make_request(method, self.url, {})
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "DELETE"])
@@ -2747,12 +2746,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "DELETE"])
def test_user_does_not_exist(self, method: str):
- """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND"""
+ """Tests that a lookup for a user that does not exist returns a 404"""
url = "/_synapse/admin/v1/users/@unknown_person:test/media"
channel = self.make_request(
method,
@@ -2760,12 +2759,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@parameterized.expand(["GET", "DELETE"])
def test_user_is_not_local(self, method: str):
- """Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST"""
+ """Tests that a lookup for a user that is not a local returns a 400"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media"
channel = self.make_request(
@@ -2774,7 +2773,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only look up local users", channel.json_body["error"])
def test_limit_GET(self):
@@ -2790,7 +2789,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 5)
self.assertEqual(channel.json_body["next_token"], 5)
@@ -2809,7 +2808,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 5)
self.assertEqual(len(channel.json_body["deleted_media"]), 5)
@@ -2826,7 +2825,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 15)
self.assertNotIn("next_token", channel.json_body)
@@ -2845,7 +2844,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 15)
self.assertEqual(len(channel.json_body["deleted_media"]), 15)
@@ -2862,7 +2861,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(channel.json_body["next_token"], 15)
self.assertEqual(len(channel.json_body["media"]), 10)
@@ -2881,7 +2880,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["deleted_media"]), 10)
@@ -2895,7 +2894,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# invalid search order
@@ -2905,7 +2904,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# negative limit
@@ -2915,7 +2914,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -2925,7 +2924,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_next_token(self):
@@ -2948,7 +2947,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), number_media)
self.assertNotIn("next_token", channel.json_body)
@@ -2961,7 +2960,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), number_media)
self.assertNotIn("next_token", channel.json_body)
@@ -2974,7 +2973,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 19)
self.assertEqual(channel.json_body["next_token"], 19)
@@ -2988,7 +2987,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -3005,7 +3004,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["media"]))
@@ -3020,7 +3019,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["deleted_media"]))
@@ -3037,7 +3036,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_media, channel.json_body["total"])
self.assertEqual(number_media, len(channel.json_body["media"]))
self.assertNotIn("next_token", channel.json_body)
@@ -3063,7 +3062,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_media, channel.json_body["total"])
self.assertEqual(number_media, len(channel.json_body["deleted_media"]))
self.assertCountEqual(channel.json_body["deleted_media"], media_ids)
@@ -3208,7 +3207,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource, image_data, user_token, filename, expect_code=HTTPStatus.OK
+ upload_resource, image_data, user_token, filename, expect_code=200
)
# Extract media ID from the response
@@ -3226,16 +3225,16 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- HTTPStatus.OK,
+ 200,
channel.code,
msg=(
- f"Expected to receive a HTTPStatus.OK on accessing media: {server_and_media_id}"
+ f"Expected to receive a 200 on accessing media: {server_and_media_id}"
),
)
return media_id
- def _check_fields(self, content: List[JsonDict]):
+ def _check_fields(self, content: JsonDict):
"""Checks that the expected user attributes are present in content
Args:
content: List that is checked for content
@@ -3275,7 +3274,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_media_list))
returned_order = [row["media_id"] for row in channel.json_body["media"]]
@@ -3311,14 +3310,14 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST", self.url, b"{}", access_token=self.admin_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
return channel.json_body["access_token"]
def test_no_auth(self):
"""Try to login as a user without authentication."""
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_not_admin(self):
@@ -3327,7 +3326,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
"POST", self.url, b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
def test_send_event(self):
"""Test that sending event as a user works."""
@@ -3352,7 +3351,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# We should only see the one device (from the login in `prepare`)
self.assertEqual(len(channel.json_body["devices"]), 1)
@@ -3364,21 +3363,21 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Logout with the puppet token
channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
# .. but the real user's tokens should still work
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def test_user_logout_all(self):
"""Tests that the target user calling `/logout/all` does *not* expire
@@ -3389,23 +3388,23 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Logout all with the real user token
channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# The puppet token should still work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# .. but the real user's tokens shouldn't
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
def test_admin_logout_all(self):
"""Tests that the admin user calling `/logout/all` does expire the
@@ -3416,23 +3415,23 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Logout all with the admin user token
channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.admin_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
# .. but the real user's tokens should still work
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
@unittest.override_config(
{
@@ -3460,10 +3459,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Now unaccept it and check that we can't send an event
self.get_success(self.store.user_set_consent_version(self.other_user, "0.0"))
self.helper.send_event(
- room_id,
- "com.example.test",
- tok=self.other_user_tok,
- expect_code=HTTPStatus.FORBIDDEN,
+ room_id, "com.example.test", tok=self.other_user_tok, expect_code=403
)
# Login in as the user
@@ -3481,10 +3477,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Trying to join as the other user should fail due to reaching MAU limit.
self.helper.join(
- room_id,
- user=self.other_user,
- tok=self.other_user_tok,
- expect_code=HTTPStatus.FORBIDDEN,
+ room_id, user=self.other_user, tok=self.other_user_tok, expect_code=403
)
# Logging in as the other user and joining a room should work, even
@@ -3519,7 +3512,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
Try to get information of an user without authentication.
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_not_admin(self):
@@ -3534,12 +3527,12 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.url,
access_token=other_user2_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_is_not_local(self):
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = self.url_prefix % "@unknown_person:unknown_domain"
@@ -3548,7 +3541,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only whois a local user", channel.json_body["error"])
def test_get_whois_admin(self):
@@ -3560,7 +3553,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
@@ -3575,7 +3568,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.url,
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
@@ -3605,7 +3598,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
Try to get information of an user without authentication.
"""
channel = self.make_request(method, self.url)
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["POST", "DELETE"])
@@ -3616,18 +3609,18 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
other_user_token = self.login("user", "pass")
channel = self.make_request(method, self.url, access_token=other_user_token)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["POST", "DELETE"])
def test_user_is_not_local(self, method: str):
"""
- Tests that shadow-banning for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that shadow-banning for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
channel = self.make_request(method, url, access_token=self.admin_user_tok)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
def test_success(self):
"""
@@ -3639,7 +3632,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
self.assertFalse(result.shadow_banned)
channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual({}, channel.json_body)
# Ensure the user is shadow-banned (and the cache was cleared).
@@ -3650,7 +3643,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE", self.url, access_token=self.admin_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual({}, channel.json_body)
# Ensure the user is no longer shadow-banned (and the cache was cleared).
@@ -3684,7 +3677,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request(method, self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "POST", "DELETE"])
@@ -3700,13 +3693,13 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "POST", "DELETE"])
def test_user_does_not_exist(self, method: str):
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit"
@@ -3716,7 +3709,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@parameterized.expand(
@@ -3728,7 +3721,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
)
def test_user_is_not_local(self, method: str, error_msg: str):
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = (
"/_synapse/admin/v1/users/@unknown_person:unknown_domain/override_ratelimit"
@@ -3740,7 +3733,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(error_msg, channel.json_body["error"])
def test_invalid_parameter(self):
@@ -3755,7 +3748,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"messages_per_second": "string"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# messages_per_second is negative
@@ -3766,7 +3759,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"messages_per_second": -1},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# burst_count is a string
@@ -3777,7 +3770,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"burst_count": "string"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# burst_count is negative
@@ -3788,7 +3781,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"burst_count": -1},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_return_zero_when_null(self):
@@ -3813,7 +3806,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["messages_per_second"])
self.assertEqual(0, channel.json_body["burst_count"])
@@ -3827,7 +3820,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body)
@@ -3838,7 +3831,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"messages_per_second": 10, "burst_count": 11},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(10, channel.json_body["messages_per_second"])
self.assertEqual(11, channel.json_body["burst_count"])
@@ -3849,7 +3842,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"messages_per_second": 20, "burst_count": 21},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(20, channel.json_body["messages_per_second"])
self.assertEqual(21, channel.json_body["burst_count"])
@@ -3859,7 +3852,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(20, channel.json_body["messages_per_second"])
self.assertEqual(21, channel.json_body["burst_count"])
@@ -3869,7 +3862,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body)
@@ -3879,6 +3872,6 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body)
diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py
index 7978626e71..4e1c49c28b 100644
--- a/tests/rest/admin/test_username_available.py
+++ b/tests/rest/admin/test_username_available.py
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
-
import synapse.rest.admin
from synapse.api.errors import Codes, SynapseError
from synapse.rest.client import login
@@ -35,38 +33,30 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
async def check_username(username):
if username == "allowed":
return True
- raise SynapseError(
- HTTPStatus.BAD_REQUEST,
- "User ID already taken.",
- errcode=Codes.USER_IN_USE,
- )
+ raise SynapseError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
handler = self.hs.get_registration_handler()
handler.check_username = check_username
def test_username_available(self):
"""
- The endpoint should return a HTTPStatus.OK response if the username does not exist
+ The endpoint should return a 200 response if the username does not exist
"""
url = "%s?username=%s" % (self.url, "allowed")
channel = self.make_request("GET", url, None, self.admin_user_tok)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertTrue(channel.json_body["available"])
def test_username_unavailable(self):
"""
- The endpoint should return a HTTPStatus.OK response if the username does not exist
+ The endpoint should return a 200 response if the username does not exist
"""
url = "%s?username=%s" % (self.url, "disallowed")
channel = self.make_request("GET", url, None, self.admin_user_tok)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["errcode"], "M_USER_IN_USE")
self.assertEqual(channel.json_body["error"], "User ID already taken.")
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 72bbc87b4a..8552671431 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from typing import Optional, Union
from twisted.internet.defer import succeed
@@ -514,39 +513,12 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)
- def use_refresh_token(self, refresh_token: str) -> FakeChannel:
- """
- Helper that makes a request to use a refresh token.
- """
- return self.make_request(
- "POST",
- "/_matrix/client/v1/refresh",
- {"refresh_token": refresh_token},
- )
-
- def is_access_token_valid(self, access_token) -> bool:
- """
- Checks whether an access token is valid, returning whether it is or not.
- """
- code = self.make_request(
- "GET", "/_matrix/client/v3/account/whoami", access_token=access_token
- ).code
-
- # Either 200 or 401 is what we get back; anything else is a bug.
- assert code in {HTTPStatus.OK, HTTPStatus.UNAUTHORIZED}
-
- return code == HTTPStatus.OK
-
def test_login_issue_refresh_token(self):
"""
A login response should include a refresh_token only if asked.
"""
# Test login
- body = {
- "type": "m.login.password",
- "user": "test",
- "password": self.user_pass,
- }
+ body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
login_without_refresh = self.make_request(
"POST", "/_matrix/client/r0/login", body
@@ -556,8 +528,8 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
login_with_refresh = self.make_request(
"POST",
- "/_matrix/client/r0/login",
- {"refresh_token": True, **body},
+ "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
+ body,
)
self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result)
self.assertIn("refresh_token", login_with_refresh.json_body)
@@ -583,12 +555,11 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
register_with_refresh = self.make_request(
"POST",
- "/_matrix/client/r0/register",
+ "/_matrix/client/r0/register?org.matrix.msc2918.refresh_token=true",
{
"username": "test3",
"password": self.user_pass,
"auth": {"type": LoginType.DUMMY},
- "refresh_token": True,
},
)
self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result)
@@ -599,22 +570,17 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
"""
A refresh token can be used to issue a new access token.
"""
- body = {
- "type": "m.login.password",
- "user": "test",
- "password": self.user_pass,
- "refresh_token": True,
- }
+ body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
login_response = self.make_request(
"POST",
- "/_matrix/client/r0/login",
+ "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
body,
)
self.assertEqual(login_response.code, 200, login_response.result)
refresh_response = self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
{"refresh_token": login_response.json_body["refresh_token"]},
)
self.assertEqual(refresh_response.code, 200, refresh_response.result)
@@ -633,19 +599,14 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
)
@override_config({"refreshable_access_token_lifetime": "1m"})
- def test_refreshable_access_token_expiration(self):
+ def test_refresh_token_expiration(self):
"""
The access token should have some time as specified in the config.
"""
- body = {
- "type": "m.login.password",
- "user": "test",
- "password": self.user_pass,
- "refresh_token": True,
- }
+ body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
login_response = self.make_request(
"POST",
- "/_matrix/client/r0/login",
+ "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
body,
)
self.assertEqual(login_response.code, 200, login_response.result)
@@ -655,198 +616,13 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
refresh_response = self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
{"refresh_token": login_response.json_body["refresh_token"]},
)
self.assertEqual(refresh_response.code, 200, refresh_response.result)
self.assertApproximates(
refresh_response.json_body["expires_in_ms"], 60 * 1000, 100
)
- access_token = refresh_response.json_body["access_token"]
-
- # Advance 59 seconds in the future (just shy of 1 minute, the time of expiry)
- self.reactor.advance(59.0)
- # Check that our token is valid
- self.assertEqual(
- self.make_request(
- "GET", "/_matrix/client/v3/account/whoami", access_token=access_token
- ).code,
- HTTPStatus.OK,
- )
-
- # Advance 2 more seconds (just past the time of expiry)
- self.reactor.advance(2.0)
- # Check that our token is invalid
- self.assertEqual(
- self.make_request(
- "GET", "/_matrix/client/v3/account/whoami", access_token=access_token
- ).code,
- HTTPStatus.UNAUTHORIZED,
- )
-
- @override_config(
- {
- "refreshable_access_token_lifetime": "1m",
- "nonrefreshable_access_token_lifetime": "10m",
- }
- )
- def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self):
- """
- Tests that the expiry times for refreshable and non-refreshable access
- tokens can be different.
- """
- body = {
- "type": "m.login.password",
- "user": "test",
- "password": self.user_pass,
- }
- login_response1 = self.make_request(
- "POST",
- "/_matrix/client/r0/login",
- {"refresh_token": True, **body},
- )
- self.assertEqual(login_response1.code, 200, login_response1.result)
- self.assertApproximates(
- login_response1.json_body["expires_in_ms"], 60 * 1000, 100
- )
- refreshable_access_token = login_response1.json_body["access_token"]
-
- login_response2 = self.make_request(
- "POST",
- "/_matrix/client/r0/login",
- body,
- )
- self.assertEqual(login_response2.code, 200, login_response2.result)
- nonrefreshable_access_token = login_response2.json_body["access_token"]
-
- # Advance 59 seconds in the future (just shy of 1 minute, the time of expiry)
- self.reactor.advance(59.0)
-
- # Both tokens should still be valid.
- self.assertTrue(self.is_access_token_valid(refreshable_access_token))
- self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
-
- # Advance to 61 s (just past 1 minute, the time of expiry)
- self.reactor.advance(2.0)
-
- # Only the non-refreshable token is still valid.
- self.assertFalse(self.is_access_token_valid(refreshable_access_token))
- self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
-
- # Advance to 599 s (just shy of 10 minutes, the time of expiry)
- self.reactor.advance(599.0 - 61.0)
-
- # It's still the case that only the non-refreshable token is still valid.
- self.assertFalse(self.is_access_token_valid(refreshable_access_token))
- self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
-
- # Advance to 601 s (just past 10 minutes, the time of expiry)
- self.reactor.advance(2.0)
-
- # Now neither token is valid.
- self.assertFalse(self.is_access_token_valid(refreshable_access_token))
- self.assertFalse(self.is_access_token_valid(nonrefreshable_access_token))
-
- @override_config(
- {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
- )
- def test_refresh_token_expiry(self):
- """
- The refresh token can be configured to have a limited lifetime.
- When that lifetime has ended, the refresh token can no longer be used to
- refresh the session.
- """
-
- body = {
- "type": "m.login.password",
- "user": "test",
- "password": self.user_pass,
- "refresh_token": True,
- }
- login_response = self.make_request(
- "POST",
- "/_matrix/client/r0/login",
- body,
- )
- self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
- refresh_token1 = login_response.json_body["refresh_token"]
-
- # Advance 119 seconds in the future (just shy of 2 minutes)
- self.reactor.advance(119.0)
-
- # Refresh our session. The refresh token should still JUST be valid right now.
- # By doing so, we get a new access token and a new refresh token.
- refresh_response = self.use_refresh_token(refresh_token1)
- self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result)
- self.assertIn(
- "refresh_token",
- refresh_response.json_body,
- "No new refresh token returned after refresh.",
- )
- refresh_token2 = refresh_response.json_body["refresh_token"]
-
- # Advance 121 seconds in the future (just a bit more than 2 minutes)
- self.reactor.advance(121.0)
-
- # Try to refresh our session, but instead notice that the refresh token is
- # not valid (it just expired).
- refresh_response = self.use_refresh_token(refresh_token2)
- self.assertEqual(
- refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result
- )
-
- @override_config(
- {
- "refreshable_access_token_lifetime": "2m",
- "refresh_token_lifetime": "2m",
- "session_lifetime": "3m",
- }
- )
- def test_ultimate_session_expiry(self):
- """
- The session can be configured to have an ultimate, limited lifetime.
- """
-
- body = {
- "type": "m.login.password",
- "user": "test",
- "password": self.user_pass,
- "refresh_token": True,
- }
- login_response = self.make_request(
- "POST",
- "/_matrix/client/r0/login",
- body,
- )
- self.assertEqual(login_response.code, 200, login_response.result)
- refresh_token = login_response.json_body["refresh_token"]
-
- # Advance shy of 2 minutes into the future
- self.reactor.advance(119.0)
-
- # Refresh our session. The refresh token should still be valid right now.
- refresh_response = self.use_refresh_token(refresh_token)
- self.assertEqual(refresh_response.code, 200, refresh_response.result)
- self.assertIn(
- "refresh_token",
- refresh_response.json_body,
- "No new refresh token returned after refresh.",
- )
- # Notice that our access token lifetime has been diminished to match the
- # session lifetime.
- # 3 minutes - 119 seconds = 61 seconds.
- self.assertEqual(refresh_response.json_body["expires_in_ms"], 61_000)
- refresh_token = refresh_response.json_body["refresh_token"]
-
- # Advance 61 seconds into the future. Our session should have expired
- # now, because we've had our 3 minutes.
- self.reactor.advance(61.0)
-
- # Try to issue a new, refreshed, access token.
- # This should fail because the refresh token's lifetime has also been
- # diminished as our session expired.
- refresh_response = self.use_refresh_token(refresh_token)
- self.assertEqual(refresh_response.code, 403, refresh_response.result)
def test_refresh_token_invalidation(self):
"""Refresh tokens are invalidated after first use of the next token.
@@ -864,15 +640,10 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|-> fourth_refresh (fails)
"""
- body = {
- "type": "m.login.password",
- "user": "test",
- "password": self.user_pass,
- "refresh_token": True,
- }
+ body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
login_response = self.make_request(
"POST",
- "/_matrix/client/r0/login",
+ "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
body,
)
self.assertEqual(login_response.code, 200, login_response.result)
@@ -880,7 +651,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
# This first refresh should work properly
first_refresh_response = self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
{"refresh_token": login_response.json_body["refresh_token"]},
)
self.assertEqual(
@@ -890,7 +661,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
# This one as well, since the token in the first one was never used
second_refresh_response = self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
{"refresh_token": login_response.json_body["refresh_token"]},
)
self.assertEqual(
@@ -900,7 +671,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
# This one should not, since the token from the first refresh is not valid anymore
third_refresh_response = self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
{"refresh_token": first_refresh_response.json_body["refresh_token"]},
)
self.assertEqual(
@@ -928,7 +699,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
# Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail
fourth_refresh_response = self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
{"refresh_token": login_response.json_body["refresh_token"]},
)
self.assertEqual(
@@ -938,7 +709,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
# But refreshing from the last valid refresh token still works
fifth_refresh_response = self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
{"refresh_token": second_refresh_response.json_body["refresh_token"]},
)
self.assertEqual(
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 397c12c2a6..eb10d43217 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -19,7 +19,7 @@ from typing import Dict, List, Optional, Tuple
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
-from synapse.rest.client import login, register, relations, room, sync
+from synapse.rest.client import login, register, relations, room
from tests import unittest
from tests.server import FakeChannel
@@ -29,7 +29,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
servlets = [
relations.register_servlets,
room.register_servlets,
- sync.register_servlets,
login.register_servlets,
register.register_servlets,
admin.register_servlets_for_client_rest_resource,
@@ -455,9 +454,11 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(400, channel.code, channel.json_body)
@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
- def test_bundled_aggregations(self):
- """Test that annotations, references, and threads get correctly bundled."""
- # Setup by sending a variety of relations.
+ def test_aggregation_get_event(self):
+ """Test that annotations, references, and threads get correctly bundled when
+ getting the parent event.
+ """
+
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body)
@@ -484,169 +485,43 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
thread_2 = channel.json_body["event_id"]
- def assert_bundle(actual):
- """Assert the expected values of the bundled aggregations."""
-
- # Ensure the fields are as expected.
- self.assertCountEqual(
- actual.keys(),
- (
- RelationTypes.ANNOTATION,
- RelationTypes.REFERENCE,
- RelationTypes.THREAD,
- ),
- )
-
- # Check the values of each field.
- self.assertEquals(
- {
- "chunk": [
- {"type": "m.reaction", "key": "a", "count": 2},
- {"type": "m.reaction", "key": "b", "count": 1},
- ]
- },
- actual[RelationTypes.ANNOTATION],
- )
-
- self.assertEquals(
- {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
- actual[RelationTypes.REFERENCE],
- )
-
- self.assertEquals(
- 2,
- actual[RelationTypes.THREAD].get("count"),
- )
- # The latest thread event has some fields that don't matter.
- self.assert_dict(
- {
- "content": {
- "m.relates_to": {
- "event_id": self.parent_id,
- "rel_type": RelationTypes.THREAD,
- }
- },
- "event_id": thread_2,
- "room_id": self.room,
- "sender": self.user_id,
- "type": "m.room.test",
- "user_id": self.user_id,
- },
- actual[RelationTypes.THREAD].get("latest_event"),
- )
-
- def _find_and_assert_event(events):
- """
- Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
- """
- for event in events:
- if event["event_id"] == self.parent_id:
- break
- else:
- raise AssertionError(f"Event {self.parent_id} not found in chunk")
- assert_bundle(event["unsigned"].get("m.relations"))
-
- # Request the event directly.
channel = self.make_request(
"GET",
- f"/rooms/{self.room}/event/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEquals(200, channel.code, channel.json_body)
- assert_bundle(channel.json_body["unsigned"].get("m.relations"))
-
- # Request the room messages.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/messages?dir=b",
- access_token=self.user_token,
- )
- self.assertEquals(200, channel.code, channel.json_body)
- _find_and_assert_event(channel.json_body["chunk"])
-
- # Request the room context.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/context/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEquals(200, channel.code, channel.json_body)
- assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations"))
-
- # Request sync.
- channel = self.make_request("GET", "/sync", access_token=self.user_token)
- self.assertEquals(200, channel.code, channel.json_body)
- room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
- self.assertTrue(room_timeline["limited"])
- _find_and_assert_event(room_timeline["events"])
-
- # Note that /relations is tested separately in test_aggregation_get_event_for_thread
- # since it needs different data configured.
-
- def test_aggregation_get_event_for_annotation(self):
- """Test that annotations do not get bundled aggregations included
- when directly requested.
- """
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEquals(200, channel.code, channel.json_body)
- annotation_id = channel.json_body["event_id"]
-
- # Annotate the annotation.
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
- )
- self.assertEquals(200, channel.code, channel.json_body)
-
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/event/{annotation_id}",
+ "/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
- self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
- def test_aggregation_get_event_for_thread(self):
- """Test that threads get bundled aggregations included when directly requested."""
- channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- self.assertEquals(200, channel.code, channel.json_body)
- thread_id = channel.json_body["event_id"]
-
- # Annotate the annotation.
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
- )
- self.assertEquals(200, channel.code, channel.json_body)
-
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/event/{thread_id}",
- access_token=self.user_token,
- )
- self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(
channel.json_body["unsigned"].get("m.relations"),
{
RelationTypes.ANNOTATION: {
- "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
+ "chunk": [
+ {"type": "m.reaction", "key": "a", "count": 2},
+ {"type": "m.reaction", "key": "b", "count": 1},
+ ]
},
- },
- )
-
- # It should also be included when the entire thread is requested.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
- access_token=self.user_token,
- )
- self.assertEquals(200, channel.code, channel.json_body)
- self.assertEqual(len(channel.json_body["chunk"]), 1)
-
- thread_message = channel.json_body["chunk"][0]
- self.assertEquals(
- thread_message["unsigned"].get("m.relations"),
- {
- RelationTypes.ANNOTATION: {
- "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
+ RelationTypes.REFERENCE: {
+ "chunk": [{"event_id": reply_1}, {"event_id": reply_2}]
+ },
+ RelationTypes.THREAD: {
+ "count": 2,
+ "latest_event": {
+ "age": 100,
+ "content": {
+ "m.relates_to": {
+ "event_id": self.parent_id,
+ "rel_type": RelationTypes.THREAD,
+ }
+ },
+ "event_id": thread_2,
+ "origin_server_ts": 1600,
+ "room_id": self.room,
+ "sender": self.user_id,
+ "type": "m.room.test",
+ "unsigned": {"age": 100},
+ "user_id": self.user_id,
+ },
},
},
)
@@ -797,56 +672,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
{"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)
- def test_edit_edit(self):
- """Test that an edit cannot be edited."""
- new_body = {"msgtype": "m.text", "body": "Initial edit"}
- channel = self._send_relation(
- RelationTypes.REPLACE,
- "m.room.message",
- content={
- "msgtype": "m.text",
- "body": "Wibble",
- "m.new_content": new_body,
- },
- )
- self.assertEquals(200, channel.code, channel.json_body)
- edit_event_id = channel.json_body["event_id"]
-
- # Edit the edit event.
- channel = self._send_relation(
- RelationTypes.REPLACE,
- "m.room.message",
- content={
- "msgtype": "m.text",
- "body": "foo",
- "m.new_content": {"msgtype": "m.text", "body": "Ignored edit"},
- },
- parent_id=edit_event_id,
- )
- self.assertEquals(200, channel.code, channel.json_body)
-
- # Request the original event.
- channel = self.make_request(
- "GET",
- "/rooms/%s/event/%s" % (self.room, self.parent_id),
- access_token=self.user_token,
- )
- self.assertEquals(200, channel.code, channel.json_body)
- # The edit to the edit should be ignored.
- self.assertEquals(channel.json_body["content"], new_body)
-
- # The relations information should not include the edit to the edit.
- relations_dict = channel.json_body["unsigned"].get("m.relations")
- self.assertIn(RelationTypes.REPLACE, relations_dict)
-
- m_replace_dict = relations_dict[RelationTypes.REPLACE]
- for key in ["event_id", "sender", "origin_server_ts"]:
- self.assertIn(key, m_replace_dict)
-
- self.assert_dict(
- {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
- )
-
def test_relations_redaction_redacts_edits(self):
"""Test that edits of an event are redacted when the original event
is redacted.
diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py
index 913bc530aa..8fe94f7d85 100644
--- a/tests/rest/media/v1/test_filepath.py
+++ b/tests/rest/media/v1/test_filepath.py
@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
-import os
from typing import Iterable
-from synapse.rest.media.v1.filepath import MediaFilePaths, _wrap_with_jail_check
+from synapse.rest.media.v1.filepath import MediaFilePaths
from tests import unittest
@@ -487,109 +486,3 @@ class MediaFilePathsTestCase(unittest.TestCase):
f"{value!r} unexpectedly passed validation: "
f"{method} returned {path_or_list!r}"
)
-
-
-class MediaFilePathsJailTestCase(unittest.TestCase):
- def _check_relative_path(self, filepaths: MediaFilePaths, path: str) -> None:
- """Passes a relative path through the jail check.
-
- Args:
- filepaths: The `MediaFilePaths` instance.
- path: A path relative to the media store directory.
-
- Raises:
- ValueError: If the jail check fails.
- """
-
- @_wrap_with_jail_check(relative=True)
- def _make_relative_path(self: MediaFilePaths, path: str) -> str:
- return path
-
- _make_relative_path(filepaths, path)
-
- def _check_absolute_path(self, filepaths: MediaFilePaths, path: str) -> None:
- """Passes an absolute path through the jail check.
-
- Args:
- filepaths: The `MediaFilePaths` instance.
- path: A path relative to the media store directory.
-
- Raises:
- ValueError: If the jail check fails.
- """
-
- @_wrap_with_jail_check(relative=False)
- def _make_absolute_path(self: MediaFilePaths, path: str) -> str:
- return os.path.join(self.base_path, path)
-
- _make_absolute_path(filepaths, path)
-
- def test_traversal_inside(self) -> None:
- """Test the jail check for paths that stay within the media directory."""
- # Despite the `../`s, these paths still lie within the media directory and it's
- # expected for the jail check to allow them through.
- # These paths ought to trip the other checks in place and should never be
- # returned.
- filepaths = MediaFilePaths("/media_store")
- path = "url_cache/2020-01-02/../../GerZNDnDZVjsOtar"
- self._check_relative_path(filepaths, path)
- self._check_absolute_path(filepaths, path)
-
- def test_traversal_outside(self) -> None:
- """Test that the jail check fails for paths that escape the media directory."""
- filepaths = MediaFilePaths("/media_store")
- path = "url_cache/2020-01-02/../../../GerZNDnDZVjsOtar"
- with self.assertRaises(ValueError):
- self._check_relative_path(filepaths, path)
- with self.assertRaises(ValueError):
- self._check_absolute_path(filepaths, path)
-
- def test_traversal_reentry(self) -> None:
- """Test the jail check for paths that exit and re-enter the media directory."""
- # These paths lie outside the media directory if it is a symlink, and inside
- # otherwise. Ideally the check should fail, but this proves difficult.
- # This test documents the behaviour for this edge case.
- # These paths ought to trip the other checks in place and should never be
- # returned.
- filepaths = MediaFilePaths("/media_store")
- path = "url_cache/2020-01-02/../../../media_store/GerZNDnDZVjsOtar"
- self._check_relative_path(filepaths, path)
- self._check_absolute_path(filepaths, path)
-
- def test_symlink(self) -> None:
- """Test that a symlink does not cause the jail check to fail."""
- media_store_path = self.mktemp()
-
- # symlink the media store directory
- os.symlink("/mnt/synapse/media_store", media_store_path)
-
- # Test that relative and absolute paths don't trip the check
- # NB: `media_store_path` is a relative path
- filepaths = MediaFilePaths(media_store_path)
- self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
- self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
-
- filepaths = MediaFilePaths(os.path.abspath(media_store_path))
- self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
- self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
-
- def test_symlink_subdirectory(self) -> None:
- """Test that a symlinked subdirectory does not cause the jail check to fail."""
- media_store_path = self.mktemp()
- os.mkdir(media_store_path)
-
- # symlink `url_cache/`
- os.symlink(
- "/mnt/synapse/media_store_url_cache",
- os.path.join(media_store_path, "url_cache"),
- )
-
- # Test that relative and absolute paths don't trip the check
- # NB: `media_store_path` is a relative path
- filepaths = MediaFilePaths(media_store_path)
- self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
- self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
-
- filepaths = MediaFilePaths(os.path.abspath(media_store_path))
- self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
- self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 5ae491ff5a..a649e8c618 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -12,24 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
-from contextlib import contextmanager
-from typing import Generator
-from twisted.enterprise.adbapi import ConnectionPool
-from twisted.internet.defer import ensureDeferred
-from twisted.test.proto_helpers import MemoryReactor
-
-from synapse.api.room_versions import EventFormatVersions, RoomVersions
from synapse.logging.context import LoggingContext
from synapse.rest import admin
from synapse.rest.client import login, room
-from synapse.server import HomeServer
-from synapse.storage.databases.main.events_worker import (
- EVENT_QUEUE_THREADS,
- EventsWorkerStore,
-)
-from synapse.storage.types import Connection
-from synapse.util import Clock
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.util.async_helpers import yieldable_gather_results
from tests import unittest
@@ -157,127 +144,3 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# We should have fetched the event from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
-
-
-class DatabaseOutageTestCase(unittest.HomeserverTestCase):
- """Test event fetching during a database outage."""
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
- self.store: EventsWorkerStore = hs.get_datastore()
-
- self.room_id = f"!room:{hs.hostname}"
- self.event_ids = [f"event{i}" for i in range(20)]
-
- self._populate_events()
-
- def _populate_events(self) -> None:
- """Ensure that there are test events in the database.
-
- When testing with the in-memory SQLite database, all the events are lost during
- the simulated outage.
-
- To ensure consistency between `room_id`s and `event_id`s before and after the
- outage, rows are built and inserted manually.
-
- Upserts are used to handle the non-SQLite case where events are not lost.
- """
- self.get_success(
- self.store.db_pool.simple_upsert(
- "rooms",
- {"room_id": self.room_id},
- {"room_version": RoomVersions.V4.identifier},
- )
- )
-
- self.event_ids = [f"event{i}" for i in range(20)]
- for idx, event_id in enumerate(self.event_ids):
- self.get_success(
- self.store.db_pool.simple_upsert(
- "events",
- {"event_id": event_id},
- {
- "event_id": event_id,
- "room_id": self.room_id,
- "topological_ordering": idx,
- "stream_ordering": idx,
- "type": "test",
- "processed": True,
- "outlier": False,
- },
- )
- )
- self.get_success(
- self.store.db_pool.simple_upsert(
- "event_json",
- {"event_id": event_id},
- {
- "room_id": self.room_id,
- "json": json.dumps({"type": "test", "room_id": self.room_id}),
- "internal_metadata": "{}",
- "format_version": EventFormatVersions.V3,
- },
- )
- )
-
- @contextmanager
- def _outage(self) -> Generator[None, None, None]:
- """Simulate a database outage.
-
- Returns:
- A context manager. While the context is active, any attempts to connect to
- the database will fail.
- """
- connection_pool = self.store.db_pool._db_pool
-
- # Close all connections and shut down the database `ThreadPool`.
- connection_pool.close()
-
- # Restart the database `ThreadPool`.
- connection_pool.start()
-
- original_connection_factory = connection_pool.connectionFactory
-
- def connection_factory(_pool: ConnectionPool) -> Connection:
- raise Exception("Could not connect to the database.")
-
- connection_pool.connectionFactory = connection_factory # type: ignore[assignment]
- try:
- yield
- finally:
- connection_pool.connectionFactory = original_connection_factory
-
- # If the in-memory SQLite database is being used, all the events are gone.
- # Restore the test data.
- self._populate_events()
-
- def test_failure(self) -> None:
- """Test that event fetches do not get stuck during a database outage."""
- with self._outage():
- failure = self.get_failure(
- self.store.get_event(self.event_ids[0]), Exception
- )
- self.assertEqual(str(failure.value), "Could not connect to the database.")
-
- def test_recovery(self) -> None:
- """Test that event fetchers recover after a database outage."""
- with self._outage():
- # Kick off a bunch of event fetches but do not pump the reactor
- event_deferreds = []
- for event_id in self.event_ids:
- event_deferreds.append(ensureDeferred(self.store.get_event(event_id)))
-
- # We should have maxed out on event fetcher threads
- self.assertEqual(self.store._event_fetch_ongoing, EVENT_QUEUE_THREADS)
-
- # All the event fetchers will fail
- self.pump()
- self.assertEqual(self.store._event_fetch_ongoing, 0)
-
- for event_deferred in event_deferreds:
- failure = self.get_failure(event_deferred, Exception)
- self.assertEqual(
- str(failure.value), "Could not connect to the database."
- )
-
- # This next event fetch should succeed
- self.get_success(self.store.get_event(self.event_ids[0]))
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 329490caad..f26d5acf9c 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -14,37 +14,35 @@
import json
import os
import tempfile
-from typing import List, Optional, cast
from unittest.mock import Mock
import yaml
from twisted.internet import defer
-from twisted.test.proto_helpers import MemoryReactor
from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.config._base import ConfigError
-from synapse.events import EventBase
-from synapse.server import HomeServer
from synapse.storage.database import DatabasePool, make_conn
from synapse.storage.databases.main.appservice import (
ApplicationServiceStore,
ApplicationServiceTransactionStore,
)
-from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
+from tests.utils import setup_test_homeserver
-class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
+class ApplicationServiceStoreTestCase(unittest.TestCase):
+ @defer.inlineCallbacks
def setUp(self):
- super(ApplicationServiceStoreTestCase, self).setUp()
-
- self.as_yaml_files: List[str] = []
+ self.as_yaml_files = []
+ hs = yield setup_test_homeserver(
+ self.addCleanup, federation_sender=Mock(), federation_client=Mock()
+ )
- self.hs.config.appservice.app_service_config_files = self.as_yaml_files
- self.hs.config.caches.event_cache_size = 1
+ hs.config.appservice.app_service_config_files = self.as_yaml_files
+ hs.config.caches.event_cache_size = 1
self.as_token = "token1"
self.as_url = "some_url"
@@ -55,14 +53,12 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
# must be done after inserts
- database = self.hs.get_datastores().databases[0]
+ database = hs.get_datastores().databases[0]
self.store = ApplicationServiceStore(
- database,
- make_conn(database._database_config, database.engine, "test"),
- self.hs,
+ database, make_conn(database._database_config, database.engine, "test"), hs
)
- def tearDown(self) -> None:
+ def tearDown(self):
# TODO: suboptimal that we need to create files for tests!
for f in self.as_yaml_files:
try:
@@ -70,9 +66,7 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
except Exception:
pass
- super(ApplicationServiceStoreTestCase, self).tearDown()
-
- def _add_appservice(self, as_token, id, url, hs_token, sender) -> None:
+ def _add_appservice(self, as_token, id, url, hs_token, sender):
as_yaml = {
"url": url,
"as_token": as_token,
@@ -86,13 +80,12 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
outfile.write(yaml.dump(as_yaml))
self.as_yaml_files.append(as_token)
- def test_retrieve_unknown_service_token(self) -> None:
+ def test_retrieve_unknown_service_token(self):
service = self.store.get_app_service_by_token("invalid_token")
self.assertEquals(service, None)
- def test_retrieval_of_service(self) -> None:
+ def test_retrieval_of_service(self):
stored_service = self.store.get_app_service_by_token(self.as_token)
- assert stored_service is not None
self.assertEquals(stored_service.token, self.as_token)
self.assertEquals(stored_service.id, self.as_id)
self.assertEquals(stored_service.url, self.as_url)
@@ -100,18 +93,22 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
self.assertEquals(stored_service.namespaces[ApplicationService.NS_ROOMS], [])
self.assertEquals(stored_service.namespaces[ApplicationService.NS_USERS], [])
- def test_retrieval_of_all_services(self) -> None:
+ def test_retrieval_of_all_services(self):
services = self.store.get_app_services()
self.assertEquals(len(services), 3)
-class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
- def setUp(self) -> None:
- super(ApplicationServiceTransactionStoreTestCase, self).setUp()
- self.as_yaml_files: List[str] = []
+class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.as_yaml_files = []
+
+ hs = yield setup_test_homeserver(
+ self.addCleanup, federation_sender=Mock(), federation_client=Mock()
+ )
- self.hs.config.appservice.app_service_config_files = self.as_yaml_files
- self.hs.config.caches.event_cache_size = 1
+ hs.config.appservice.app_service_config_files = self.as_yaml_files
+ hs.config.caches.event_cache_size = 1
self.as_list = [
{"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
@@ -120,21 +117,21 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
{"token": "gamma_tok", "url": "https://gamma.com", "id": "id_gamma"},
]
for s in self.as_list:
- self._add_service(s["url"], s["token"], s["id"])
+ yield self._add_service(s["url"], s["token"], s["id"])
self.as_yaml_files = []
# We assume there is only one database in these tests
- database = self.hs.get_datastores().databases[0]
+ database = hs.get_datastores().databases[0]
self.db_pool = database._db_pool
self.engine = database.engine
- db_config = self.hs.config.database.get_single_database()
+ db_config = hs.config.database.get_single_database()
self.store = TestTransactionStore(
- database, make_conn(db_config, self.engine, "test"), self.hs
+ database, make_conn(db_config, self.engine, "test"), hs
)
- def _add_service(self, url, as_token, id) -> None:
+ def _add_service(self, url, as_token, id):
as_yaml = {
"url": url,
"as_token": as_token,
@@ -148,15 +145,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
outfile.write(yaml.dump(as_yaml))
self.as_yaml_files.append(as_token)
- def _set_state(
- self, id: str, state: ApplicationServiceState, txn: Optional[int] = None
- ):
+ def _set_state(self, id, state, txn=None):
return self.db_pool.runOperation(
self.engine.convert_param_style(
"INSERT INTO application_services_state(as_id, state, last_txn) "
"VALUES(?,?,?)"
),
- (id, state.value, txn),
+ (id, state, txn),
)
def _insert_txn(self, as_id, txn_id, events):
@@ -174,277 +169,234 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
"INSERT INTO application_services_state(as_id, last_txn, state) "
"VALUES(?,?,?)"
),
- (as_id, txn_id, ApplicationServiceState.UP.value),
+ (as_id, txn_id, ApplicationServiceState.UP),
)
- def test_get_appservice_state_none(
- self,
- ) -> None:
+ @defer.inlineCallbacks
+ def test_get_appservice_state_none(self):
service = Mock(id="999")
- state = self.get_success(self.store.get_appservice_state(service))
+ state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
self.assertEquals(None, state)
- def test_get_appservice_state_up(
- self,
- ) -> None:
- self.get_success(
- self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
- )
+ @defer.inlineCallbacks
+ def test_get_appservice_state_up(self):
+ yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
service = Mock(id=self.as_list[0]["id"])
- state = self.get_success(
- defer.ensureDeferred(self.store.get_appservice_state(service))
- )
+ state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
self.assertEquals(ApplicationServiceState.UP, state)
- def test_get_appservice_state_down(
- self,
- ) -> None:
- self.get_success(
- self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
- )
- self.get_success(
- self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN)
- )
- self.get_success(
- self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
- )
+ @defer.inlineCallbacks
+ def test_get_appservice_state_down(self):
+ yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
+ yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN)
+ yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
service = Mock(id=self.as_list[1]["id"])
- state = self.get_success(self.store.get_appservice_state(service))
+ state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
self.assertEquals(ApplicationServiceState.DOWN, state)
- def test_get_appservices_by_state_none(
- self,
- ) -> None:
- services = self.get_success(
+ @defer.inlineCallbacks
+ def test_get_appservices_by_state_none(self):
+ services = yield defer.ensureDeferred(
self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
)
self.assertEquals(0, len(services))
- def test_set_appservices_state_down(
- self,
- ) -> None:
+ @defer.inlineCallbacks
+ def test_set_appservices_state_down(self):
service = Mock(id=self.as_list[1]["id"])
- self.get_success(
+ yield defer.ensureDeferred(
self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
)
- rows = self.get_success(
- self.db_pool.runQuery(
- self.engine.convert_param_style(
- "SELECT as_id FROM application_services_state WHERE state=?"
- ),
- (ApplicationServiceState.DOWN.value,),
- )
+ rows = yield self.db_pool.runQuery(
+ self.engine.convert_param_style(
+ "SELECT as_id FROM application_services_state WHERE state=?"
+ ),
+ (ApplicationServiceState.DOWN,),
)
self.assertEquals(service.id, rows[0][0])
- def test_set_appservices_state_multiple_up(
- self,
- ) -> None:
+ @defer.inlineCallbacks
+ def test_set_appservices_state_multiple_up(self):
service = Mock(id=self.as_list[1]["id"])
- self.get_success(
+ yield defer.ensureDeferred(
self.store.set_appservice_state(service, ApplicationServiceState.UP)
)
- self.get_success(
+ yield defer.ensureDeferred(
self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
)
- self.get_success(
+ yield defer.ensureDeferred(
self.store.set_appservice_state(service, ApplicationServiceState.UP)
)
- rows = self.get_success(
- self.db_pool.runQuery(
- self.engine.convert_param_style(
- "SELECT as_id FROM application_services_state WHERE state=?"
- ),
- (ApplicationServiceState.UP.value,),
- )
+ rows = yield self.db_pool.runQuery(
+ self.engine.convert_param_style(
+ "SELECT as_id FROM application_services_state WHERE state=?"
+ ),
+ (ApplicationServiceState.UP,),
)
self.assertEquals(service.id, rows[0][0])
- def test_create_appservice_txn_first(
- self,
- ) -> None:
+ @defer.inlineCallbacks
+ def test_create_appservice_txn_first(self):
service = Mock(id=self.as_list[0]["id"])
- events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
- txn = self.get_success(
- defer.ensureDeferred(self.store.create_appservice_txn(service, events, []))
+ events = [Mock(event_id="e1"), Mock(event_id="e2")]
+ txn = yield defer.ensureDeferred(
+ self.store.create_appservice_txn(service, events, [])
)
self.assertEquals(txn.id, 1)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
- def test_create_appservice_txn_older_last_txn(
- self,
- ) -> None:
+ @defer.inlineCallbacks
+ def test_create_appservice_txn_older_last_txn(self):
service = Mock(id=self.as_list[0]["id"])
- events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
- self.get_success(self._set_last_txn(service.id, 9643)) # AS is falling behind
- self.get_success(self._insert_txn(service.id, 9644, events))
- self.get_success(self._insert_txn(service.id, 9645, events))
- txn = self.get_success(self.store.create_appservice_txn(service, events, []))
+ events = [Mock(event_id="e1"), Mock(event_id="e2")]
+ yield self._set_last_txn(service.id, 9643) # AS is falling behind
+ yield self._insert_txn(service.id, 9644, events)
+ yield self._insert_txn(service.id, 9645, events)
+ txn = yield defer.ensureDeferred(
+ self.store.create_appservice_txn(service, events, [])
+ )
self.assertEquals(txn.id, 9646)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
- def test_create_appservice_txn_up_to_date_last_txn(
- self,
- ) -> None:
+ @defer.inlineCallbacks
+ def test_create_appservice_txn_up_to_date_last_txn(self):
service = Mock(id=self.as_list[0]["id"])
- events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
- self.get_success(self._set_last_txn(service.id, 9643))
- txn = self.get_success(self.store.create_appservice_txn(service, events, []))
+ events = [Mock(event_id="e1"), Mock(event_id="e2")]
+ yield self._set_last_txn(service.id, 9643)
+ txn = yield defer.ensureDeferred(
+ self.store.create_appservice_txn(service, events, [])
+ )
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
- def test_create_appservice_txn_up_fuzzing(
- self,
- ) -> None:
+ @defer.inlineCallbacks
+ def test_create_appservice_txn_up_fuzzing(self):
service = Mock(id=self.as_list[0]["id"])
- events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
- self.get_success(self._set_last_txn(service.id, 9643))
+ events = [Mock(event_id="e1"), Mock(event_id="e2")]
+ yield self._set_last_txn(service.id, 9643)
# dump in rows with higher IDs to make sure the queries aren't wrong.
- self.get_success(self._set_last_txn(self.as_list[1]["id"], 119643))
- self.get_success(self._set_last_txn(self.as_list[2]["id"], 9))
- self.get_success(self._set_last_txn(self.as_list[3]["id"], 9643))
- self.get_success(self._insert_txn(self.as_list[1]["id"], 119644, events))
- self.get_success(self._insert_txn(self.as_list[1]["id"], 119645, events))
- self.get_success(self._insert_txn(self.as_list[1]["id"], 119646, events))
- self.get_success(self._insert_txn(self.as_list[2]["id"], 10, events))
- self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events))
-
- txn = self.get_success(self.store.create_appservice_txn(service, events, []))
+ yield self._set_last_txn(self.as_list[1]["id"], 119643)
+ yield self._set_last_txn(self.as_list[2]["id"], 9)
+ yield self._set_last_txn(self.as_list[3]["id"], 9643)
+ yield self._insert_txn(self.as_list[1]["id"], 119644, events)
+ yield self._insert_txn(self.as_list[1]["id"], 119645, events)
+ yield self._insert_txn(self.as_list[1]["id"], 119646, events)
+ yield self._insert_txn(self.as_list[2]["id"], 10, events)
+ yield self._insert_txn(self.as_list[3]["id"], 9643, events)
+
+ txn = yield defer.ensureDeferred(
+ self.store.create_appservice_txn(service, events, [])
+ )
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
- def test_complete_appservice_txn_first_txn(
- self,
- ) -> None:
+ @defer.inlineCallbacks
+ def test_complete_appservice_txn_first_txn(self):
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
txn_id = 1
- self.get_success(self._insert_txn(service.id, txn_id, events))
- self.get_success(
+ yield self._insert_txn(service.id, txn_id, events)
+ yield defer.ensureDeferred(
self.store.complete_appservice_txn(txn_id=txn_id, service=service)
)
- res = self.get_success(
- self.db_pool.runQuery(
- self.engine.convert_param_style(
- "SELECT last_txn FROM application_services_state WHERE as_id=?"
- ),
- (service.id,),
- )
+ res = yield self.db_pool.runQuery(
+ self.engine.convert_param_style(
+ "SELECT last_txn FROM application_services_state WHERE as_id=?"
+ ),
+ (service.id,),
)
self.assertEquals(1, len(res))
self.assertEquals(txn_id, res[0][0])
- res = self.get_success(
- self.db_pool.runQuery(
- self.engine.convert_param_style(
- "SELECT * FROM application_services_txns WHERE txn_id=?"
- ),
- (txn_id,),
- )
+ res = yield self.db_pool.runQuery(
+ self.engine.convert_param_style(
+ "SELECT * FROM application_services_txns WHERE txn_id=?"
+ ),
+ (txn_id,),
)
self.assertEquals(0, len(res))
- def test_complete_appservice_txn_existing_in_state_table(
- self,
- ) -> None:
+ @defer.inlineCallbacks
+ def test_complete_appservice_txn_existing_in_state_table(self):
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
txn_id = 5
- self.get_success(self._set_last_txn(service.id, 4))
- self.get_success(self._insert_txn(service.id, txn_id, events))
- self.get_success(
+ yield self._set_last_txn(service.id, 4)
+ yield self._insert_txn(service.id, txn_id, events)
+ yield defer.ensureDeferred(
self.store.complete_appservice_txn(txn_id=txn_id, service=service)
)
- res = self.get_success(
- self.db_pool.runQuery(
- self.engine.convert_param_style(
- "SELECT last_txn, state FROM application_services_state WHERE as_id=?"
- ),
- (service.id,),
- )
+ res = yield self.db_pool.runQuery(
+ self.engine.convert_param_style(
+ "SELECT last_txn, state FROM application_services_state WHERE as_id=?"
+ ),
+ (service.id,),
)
self.assertEquals(1, len(res))
self.assertEquals(txn_id, res[0][0])
- self.assertEquals(ApplicationServiceState.UP.value, res[0][1])
-
- res = self.get_success(
- self.db_pool.runQuery(
- self.engine.convert_param_style(
- "SELECT * FROM application_services_txns WHERE txn_id=?"
- ),
- (txn_id,),
- )
+ self.assertEquals(ApplicationServiceState.UP, res[0][1])
+
+ res = yield self.db_pool.runQuery(
+ self.engine.convert_param_style(
+ "SELECT * FROM application_services_txns WHERE txn_id=?"
+ ),
+ (txn_id,),
)
self.assertEquals(0, len(res))
- def test_get_oldest_unsent_txn_none(
- self,
- ) -> None:
+ @defer.inlineCallbacks
+ def test_get_oldest_unsent_txn_none(self):
service = Mock(id=self.as_list[0]["id"])
- txn = self.get_success(self.store.get_oldest_unsent_txn(service))
+ txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
self.assertEquals(None, txn)
- def test_get_oldest_unsent_txn(self) -> None:
+ @defer.inlineCallbacks
+ def test_get_oldest_unsent_txn(self):
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
# we aren't testing store._base stuff here, so mock this out
- # (ignore needed because Mypy won't allow us to assign to a method otherwise)
- self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) # type: ignore[assignment]
+ self.store.get_events_as_list = Mock(return_value=make_awaitable(events))
- self.get_success(self._insert_txn(self.as_list[1]["id"], 9, other_events))
- self.get_success(self._insert_txn(service.id, 10, events))
- self.get_success(self._insert_txn(service.id, 11, other_events))
- self.get_success(self._insert_txn(service.id, 12, other_events))
+ yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
+ yield self._insert_txn(service.id, 10, events)
+ yield self._insert_txn(service.id, 11, other_events)
+ yield self._insert_txn(service.id, 12, other_events)
- txn = self.get_success(self.store.get_oldest_unsent_txn(service))
+ txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
self.assertEquals(service, txn.service)
self.assertEquals(10, txn.id)
self.assertEquals(events, txn.events)
- def test_get_appservices_by_state_single(
- self,
- ) -> None:
- self.get_success(
- self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
- )
- self.get_success(
- self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
- )
+ @defer.inlineCallbacks
+ def test_get_appservices_by_state_single(self):
+ yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
+ yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
- services = self.get_success(
+ services = yield defer.ensureDeferred(
self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
)
self.assertEquals(1, len(services))
self.assertEquals(self.as_list[0]["id"], services[0].id)
- def test_get_appservices_by_state_multiple(
- self,
- ) -> None:
- self.get_success(
- self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
- )
- self.get_success(
- self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
- )
- self.get_success(
- self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
- )
- self.get_success(
- self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP)
- )
+ @defer.inlineCallbacks
+ def test_get_appservices_by_state_multiple(self):
+ yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
+ yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
+ yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
+ yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP)
- services = self.get_success(
+ services = yield defer.ensureDeferred(
self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
)
self.assertEquals(2, len(services))
@@ -455,16 +407,16 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
- def prepare(
- self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
- ) -> None:
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+ return hs
+
+ def prepare(self, hs, reactor, clock):
self.service = Mock(id="foo")
self.store = self.hs.get_datastore()
- self.get_success(
- self.store.set_appservice_state(self.service, ApplicationServiceState.UP)
- )
+ self.get_success(self.store.set_appservice_state(self.service, "up"))
- def test_get_type_stream_id_for_appservice_no_value(self) -> None:
+ def test_get_type_stream_id_for_appservice_no_value(self):
value = self.get_success(
self.store.get_type_stream_id_for_appservice(self.service, "read_receipt")
)
@@ -475,13 +427,13 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
)
self.assertEquals(value, 0)
- def test_get_type_stream_id_for_appservice_invalid_type(self) -> None:
+ def test_get_type_stream_id_for_appservice_invalid_type(self):
self.get_failure(
self.store.get_type_stream_id_for_appservice(self.service, "foobar"),
ValueError,
)
- def test_set_type_stream_id_for_appservice(self) -> None:
+ def test_set_type_stream_id_for_appservice(self):
read_receipt_value = 1024
self.get_success(
self.store.set_type_stream_id_for_appservice(
@@ -503,7 +455,7 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
)
self.assertEqual(result, read_receipt_value)
- def test_set_type_stream_id_for_appservice_invalid_type(self) -> None:
+ def test_set_type_stream_id_for_appservice_invalid_type(self):
self.get_failure(
self.store.set_type_stream_id_for_appservice(self.service, "foobar", 1024),
ValueError,
@@ -512,12 +464,12 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
# required for ApplicationServiceTransactionStoreTestCase tests
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
- def __init__(self, database: DatabasePool, db_conn, hs) -> None:
+ def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
-class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase):
- def _write_config(self, suffix, **kwargs) -> str:
+class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
+ def _write_config(self, suffix, **kwargs):
vals = {
"id": "id" + suffix,
"url": "url" + suffix,
@@ -533,33 +485,41 @@ class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase):
f.write(yaml.dump(vals))
return path
- def test_unique_works(self) -> None:
+ @defer.inlineCallbacks
+ def test_unique_works(self):
f1 = self._write_config(suffix="1")
f2 = self._write_config(suffix="2")
- self.hs.config.appservice.app_service_config_files = [f1, f2]
- self.hs.config.caches.event_cache_size = 1
+ hs = yield setup_test_homeserver(
+ self.addCleanup, federation_sender=Mock(), federation_client=Mock()
+ )
+
+ hs.config.appservice.app_service_config_files = [f1, f2]
+ hs.config.caches.event_cache_size = 1
- database = self.hs.get_datastores().databases[0]
+ database = hs.get_datastores().databases[0]
ApplicationServiceStore(
- database,
- make_conn(database._database_config, database.engine, "test"),
- self.hs,
+ database, make_conn(database._database_config, database.engine, "test"), hs
)
- def test_duplicate_ids(self) -> None:
+ @defer.inlineCallbacks
+ def test_duplicate_ids(self):
f1 = self._write_config(id="id", suffix="1")
f2 = self._write_config(id="id", suffix="2")
- self.hs.config.appservice.app_service_config_files = [f1, f2]
- self.hs.config.caches.event_cache_size = 1
+ hs = yield setup_test_homeserver(
+ self.addCleanup, federation_sender=Mock(), federation_client=Mock()
+ )
+
+ hs.config.appservice.app_service_config_files = [f1, f2]
+ hs.config.caches.event_cache_size = 1
with self.assertRaises(ConfigError) as cm:
- database = self.hs.get_datastores().databases[0]
+ database = hs.get_datastores().databases[0]
ApplicationServiceStore(
database,
make_conn(database._database_config, database.engine, "test"),
- self.hs,
+ hs,
)
e = cm.exception
@@ -567,19 +527,24 @@ class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase):
self.assertIn(f2, str(e))
self.assertIn("id", str(e))
- def test_duplicate_as_tokens(self) -> None:
+ @defer.inlineCallbacks
+ def test_duplicate_as_tokens(self):
f1 = self._write_config(as_token="as_token", suffix="1")
f2 = self._write_config(as_token="as_token", suffix="2")
- self.hs.config.appservice.app_service_config_files = [f1, f2]
- self.hs.config.caches.event_cache_size = 1
+ hs = yield setup_test_homeserver(
+ self.addCleanup, federation_sender=Mock(), federation_client=Mock()
+ )
+
+ hs.config.appservice.app_service_config_files = [f1, f2]
+ hs.config.caches.event_cache_size = 1
with self.assertRaises(ConfigError) as cm:
- database = self.hs.get_datastores().databases[0]
+ database = hs.get_datastores().databases[0]
ApplicationServiceStore(
database,
make_conn(database._database_config, database.engine, "test"),
- self.hs,
+ hs,
)
e = cm.exception
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index d77c001506..a5f5ebad41 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -1,26 +1,8 @@
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Use backported mock for AsyncMock support on Python 3.6.
-from mock import Mock
-
-from twisted.internet.defer import Deferred, ensureDeferred
+from unittest.mock import Mock
from synapse.storage.background_updates import BackgroundUpdater
from tests import unittest
-from tests.test_utils import make_awaitable
class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
@@ -38,10 +20,10 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
def test_do_background_update(self):
# the time we claim it takes to update one item when running the update
- duration_ms = 10
+ duration_ms = 4200
# the target runtime for each bg update
- target_background_update_duration_ms = 100
+ target_background_update_duration_ms = 5000000
store = self.hs.get_datastore()
self.get_success(
@@ -66,8 +48,10 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = update
self.update_handler.reset_mock()
res = self.get_success(
- self.updates.do_next_background_update(False),
- by=0.01,
+ self.updates.do_next_background_update(
+ target_background_update_duration_ms
+ ),
+ by=0.1,
)
self.assertFalse(res)
@@ -90,93 +74,16 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = update
self.update_handler.reset_mock()
- result = self.get_success(self.updates.do_next_background_update(False))
+ result = self.get_success(
+ self.updates.do_next_background_update(target_background_update_duration_ms)
+ )
self.assertFalse(result)
self.update_handler.assert_called_once()
# third step: we don't expect to be called any more
self.update_handler.reset_mock()
- result = self.get_success(self.updates.do_next_background_update(False))
+ result = self.get_success(
+ self.updates.do_next_background_update(target_background_update_duration_ms)
+ )
self.assertTrue(result)
self.assertFalse(self.update_handler.called)
-
-
-class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, homeserver):
- self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates
- # the base test class should have run the real bg updates for us
- self.assertTrue(
- self.get_success(self.updates.has_completed_background_updates())
- )
-
- self.update_deferred = Deferred()
- self.update_handler = Mock(return_value=self.update_deferred)
- self.updates.register_background_update_handler(
- "test_update", self.update_handler
- )
-
- # Mock out the AsyncContextManager
- self._update_ctx_manager = Mock(spec=["__aenter__", "__aexit__"])
- self._update_ctx_manager.__aenter__ = Mock(
- return_value=make_awaitable(None),
- )
- self._update_ctx_manager.__aexit__ = Mock(return_value=make_awaitable(None))
-
- # Mock out the `update_handler` callback
- self._on_update = Mock(return_value=self._update_ctx_manager)
-
- # Define a default batch size value that's not the same as the internal default
- # value (100).
- self._default_batch_size = 500
-
- # Register the callbacks with more mocks
- self.hs.get_module_api().register_background_update_controller_callbacks(
- on_update=self._on_update,
- min_batch_size=Mock(return_value=make_awaitable(self._default_batch_size)),
- default_batch_size=Mock(
- return_value=make_awaitable(self._default_batch_size),
- ),
- )
-
- def test_controller(self):
- store = self.hs.get_datastore()
- self.get_success(
- store.db_pool.simple_insert(
- "background_updates",
- values={"update_name": "test_update", "progress_json": "{}"},
- )
- )
-
- # Set the return value for the context manager.
- enter_defer = Deferred()
- self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer)
-
- # Start the background update.
- do_update_d = ensureDeferred(self.updates.do_next_background_update(True))
-
- self.pump()
-
- # `run_update` should have been called, but the update handler won't be
- # called until the `enter_defer` (returned by `__aenter__`) is resolved.
- self._on_update.assert_called_once_with(
- "test_update",
- "master",
- False,
- )
- self.assertFalse(do_update_d.called)
- self.assertFalse(self.update_deferred.called)
-
- # Resolving the `enter_defer` should call the update handler, which then
- # blocks.
- enter_defer.callback(100)
- self.pump()
- self.update_handler.assert_called_once_with({}, self._default_batch_size)
- self.assertFalse(self.update_deferred.called)
- self._update_ctx_manager.__aexit__.assert_not_called()
-
- # Resolving the update handler deferred should cause the
- # `do_next_background_update` to finish and return
- self.update_deferred.callback(100)
- self.pump()
- self._update_ctx_manager.__aexit__.assert_called()
- self.get_success(do_update_d)
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 7b7f6c349e..b31c5eb5ec 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -664,7 +664,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
):
iterations += 1
self.get_success(
- self.store.db_pool.updates.do_next_background_update(False), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
# Ensure that we did actually take multiple iterations to process the
@@ -723,7 +723,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
):
iterations += 1
self.get_success(
- self.store.db_pool.updates.do_next_background_update(False), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
# Ensure that we did actually take multiple iterations to process the
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index f8d11bac4e..d2b7b89952 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -13,35 +13,42 @@
# limitations under the License.
+from twisted.internet import defer
+
from synapse.types import UserID
from tests import unittest
+from tests.utils import setup_test_homeserver
-class DataStoreTestCase(unittest.HomeserverTestCase):
- def setUp(self) -> None:
- super(DataStoreTestCase, self).setUp()
+class DataStoreTestCase(unittest.TestCase):
+ @defer.inlineCallbacks
+ def setUp(self):
+ hs = yield setup_test_homeserver(self.addCleanup)
- self.store = self.hs.get_datastore()
+ self.store = hs.get_datastore()
self.user = UserID.from_string("@abcde:test")
self.displayname = "Frank"
- def test_get_users_paginate(self) -> None:
- self.get_success(self.store.register_user(self.user.to_string(), "pass"))
- self.get_success(self.store.create_profile(self.user.localpart))
- self.get_success(
+ @defer.inlineCallbacks
+ def test_get_users_paginate(self):
+ yield defer.ensureDeferred(
+ self.store.register_user(self.user.to_string(), "pass")
+ )
+ yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
+ yield defer.ensureDeferred(
self.store.set_profile_displayname(self.user.localpart, self.displayname)
)
- users, total = self.get_success(
+ users, total = yield defer.ensureDeferred(
self.store.get_users_paginate(0, 10, name="bc", guests=False)
)
self.assertEquals(1, total)
self.assertEquals(self.displayname, users.pop()["displayname"])
- users, total = self.get_success(
+ users, total = yield defer.ensureDeferred(
self.store.get_users_paginate(0, 10, name="BC", guests=False)
)
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 7f5b28aed8..37cf7bb232 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -23,7 +23,6 @@ from synapse.rest import admin
from synapse.rest.client import login, register, room
from synapse.server import HomeServer
from synapse.storage import DataStore
-from synapse.storage.background_updates import _BackgroundUpdateHandler
from synapse.storage.roommember import ProfileInfo
from synapse.util import Clock
@@ -392,9 +391,7 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
with mock.patch.dict(
self.store.db_pool.updates._background_update_handlers,
- populate_user_directory_process_users=_BackgroundUpdateHandler(
- mocked_process_users,
- ),
+ populate_user_directory_process_users=mocked_process_users,
):
self._purge_and_rebuild_user_dir()
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index e0b08d67d4..94b19788d7 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -13,30 +13,35 @@
# limitations under the License.
import logging
from typing import Optional
+from unittest.mock import Mock
+
+from twisted.internet import defer
+from twisted.internet.defer import succeed
from synapse.api.room_versions import RoomVersions
-from synapse.events import EventBase
-from synapse.types import JsonDict
+from synapse.events import FrozenEvent
from synapse.visibility import filter_events_for_server
-from tests import unittest
-from tests.utils import create_room
+import tests.unittest
+from tests.utils import create_room, setup_test_homeserver
logger = logging.getLogger(__name__)
TEST_ROOM_ID = "!TEST:ROOM"
-class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
- def setUp(self) -> None:
- super(FilterEventsForServerTestCase, self).setUp()
+class FilterEventsForServerTestCase(tests.unittest.TestCase):
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.hs = yield setup_test_homeserver(self.addCleanup)
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
self.storage = self.hs.get_storage()
- self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
+ yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
- def test_filtering(self) -> None:
+ @defer.inlineCallbacks
+ def test_filtering(self):
#
# The events to be filtered consist of 10 membership events (it doesn't
# really matter if they are joins or leaves, so let's make them joins).
@@ -46,20 +51,18 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
#
# before we do that, we persist some other events to act as state.
- self.get_success(self._inject_visibility("@admin:hs", "joined"))
+ yield self.inject_visibility("@admin:hs", "joined")
for i in range(0, 10):
- self.get_success(self._inject_room_member("@resident%i:hs" % i))
+ yield self.inject_room_member("@resident%i:hs" % i)
events_to_filter = []
for i in range(0, 10):
user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
- evt = self.get_success(
- self._inject_room_member(user, extra_content={"a": "b"})
- )
+ evt = yield self.inject_room_member(user, extra_content={"a": "b"})
events_to_filter.append(evt)
- filtered = self.get_success(
+ filtered = yield defer.ensureDeferred(
filter_events_for_server(self.storage, "test_server", events_to_filter)
)
@@ -72,31 +75,34 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
self.assertEqual(filtered[i].content["a"], "b")
- def test_erased_user(self) -> None:
+ @defer.inlineCallbacks
+ def test_erased_user(self):
# 4 message events, from erased and unerased users, with a membership
# change in the middle of them.
events_to_filter = []
- evt = self.get_success(self._inject_message("@unerased:local_hs"))
+ evt = yield self.inject_message("@unerased:local_hs")
events_to_filter.append(evt)
- evt = self.get_success(self._inject_message("@erased:local_hs"))
+ evt = yield self.inject_message("@erased:local_hs")
events_to_filter.append(evt)
- evt = self.get_success(self._inject_room_member("@joiner:remote_hs"))
+ evt = yield self.inject_room_member("@joiner:remote_hs")
events_to_filter.append(evt)
- evt = self.get_success(self._inject_message("@unerased:local_hs"))
+ evt = yield self.inject_message("@unerased:local_hs")
events_to_filter.append(evt)
- evt = self.get_success(self._inject_message("@erased:local_hs"))
+ evt = yield self.inject_message("@erased:local_hs")
events_to_filter.append(evt)
# the erasey user gets erased
- self.get_success(self.hs.get_datastore().mark_user_erased("@erased:local_hs"))
+ yield defer.ensureDeferred(
+ self.hs.get_datastore().mark_user_erased("@erased:local_hs")
+ )
# ... and the filtering happens.
- filtered = self.get_success(
+ filtered = yield defer.ensureDeferred(
filter_events_for_server(self.storage, "test_server", events_to_filter)
)
@@ -117,7 +123,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
for i in (1, 4):
self.assertNotIn("body", filtered[i].content)
- def _inject_visibility(self, user_id: str, visibility: str) -> EventBase:
+ @defer.inlineCallbacks
+ def inject_visibility(self, user_id, visibility):
content = {"history_visibility": visibility}
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
@@ -130,18 +137,18 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, context = yield defer.ensureDeferred(
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ yield defer.ensureDeferred(
+ self.storage.persistence.persist_event(event, context)
+ )
return event
- def _inject_room_member(
- self,
- user_id: str,
- membership: str = "join",
- extra_content: Optional[JsonDict] = None,
- ) -> EventBase:
+ @defer.inlineCallbacks
+ def inject_room_member(
+ self, user_id, membership="join", extra_content: Optional[dict] = None
+ ):
content = {"membership": membership}
content.update(extra_content or {})
builder = self.event_builder_factory.for_room_version(
@@ -155,16 +162,17 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, context = yield defer.ensureDeferred(
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ yield defer.ensureDeferred(
+ self.storage.persistence.persist_event(event, context)
+ )
return event
- def _inject_message(
- self, user_id: str, content: Optional[JsonDict] = None
- ) -> EventBase:
+ @defer.inlineCallbacks
+ def inject_message(self, user_id, content=None):
if content is None:
content = {"body": "testytest", "msgtype": "m.text"}
builder = self.event_builder_factory.for_room_version(
@@ -177,9 +185,164 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, context = yield defer.ensureDeferred(
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ yield defer.ensureDeferred(
+ self.storage.persistence.persist_event(event, context)
+ )
return event
+
+ @defer.inlineCallbacks
+ def test_large_room(self):
+ # see what happens when we have a large room with hundreds of thousands
+ # of membership events
+
+ # As above, the events to be filtered consist of 10 membership events,
+ # where one of them is for a user on the server we are filtering for.
+
+ import cProfile
+ import pstats
+ import time
+
+ # we stub out the store, because building up all that state the normal
+ # way is very slow.
+ test_store = _TestStore()
+
+ # our initial state is 100000 membership events and one
+ # history_visibility event.
+ room_state = []
+
+ history_visibility_evt = FrozenEvent(
+ {
+ "event_id": "$history_vis",
+ "type": "m.room.history_visibility",
+ "sender": "@resident_user_0:test.com",
+ "state_key": "",
+ "room_id": TEST_ROOM_ID,
+ "content": {"history_visibility": "joined"},
+ }
+ )
+ room_state.append(history_visibility_evt)
+ test_store.add_event(history_visibility_evt)
+
+ for i in range(0, 100000):
+ user = "@resident_user_%i:test.com" % (i,)
+ evt = FrozenEvent(
+ {
+ "event_id": "$res_event_%i" % (i,),
+ "type": "m.room.member",
+ "state_key": user,
+ "sender": user,
+ "room_id": TEST_ROOM_ID,
+ "content": {"membership": "join", "extra": "zzz,"},
+ }
+ )
+ room_state.append(evt)
+ test_store.add_event(evt)
+
+ events_to_filter = []
+ for i in range(0, 10):
+ user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
+ evt = FrozenEvent(
+ {
+ "event_id": "$evt%i" % (i,),
+ "type": "m.room.member",
+ "state_key": user,
+ "sender": user,
+ "room_id": TEST_ROOM_ID,
+ "content": {"membership": "join", "extra": "zzz"},
+ }
+ )
+ events_to_filter.append(evt)
+ room_state.append(evt)
+
+ test_store.add_event(evt)
+ test_store.set_state_ids_for_event(
+ evt, {(e.type, e.state_key): e.event_id for e in room_state}
+ )
+
+ pr = cProfile.Profile()
+ pr.enable()
+
+ logger.info("Starting filtering")
+ start = time.time()
+
+ storage = Mock()
+ storage.main = test_store
+ storage.state = test_store
+
+ filtered = yield defer.ensureDeferred(
+ filter_events_for_server(test_store, "test_server", events_to_filter)
+ )
+ logger.info("Filtering took %f seconds", time.time() - start)
+
+ pr.disable()
+ with open("filter_events_for_server.profile", "w+") as f:
+ ps = pstats.Stats(pr, stream=f).sort_stats("cumulative")
+ ps.print_stats()
+
+ # the result should be 5 redacted events, and 5 unredacted events.
+ for i in range(0, 5):
+ self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
+ self.assertNotIn("extra", filtered[i].content)
+
+ for i in range(5, 10):
+ self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
+ self.assertEqual(filtered[i].content["extra"], "zzz")
+
+ test_large_room.skip = "Disabled by default because it's slow"
+
+
+class _TestStore:
+ """Implements a few methods of the DataStore, so that we can test
+ filter_events_for_server
+
+ """
+
+ def __init__(self):
+ # data for get_events: a map from event_id to event
+ self.events = {}
+
+ # data for get_state_ids_for_events mock: a map from event_id to
+ # a map from (type_state_key) -> event_id for the state at that
+ # event
+ self.state_ids_for_events = {}
+
+ def add_event(self, event):
+ self.events[event.event_id] = event
+
+ def set_state_ids_for_event(self, event, state):
+ self.state_ids_for_events[event.event_id] = state
+
+ def get_state_ids_for_events(self, events, types):
+ res = {}
+ include_memberships = False
+ for (type, state_key) in types:
+ if type == "m.room.history_visibility":
+ continue
+ if type != "m.room.member" or state_key is not None:
+ raise RuntimeError(
+ "Unimplemented: get_state_ids with type (%s, %s)"
+ % (type, state_key)
+ )
+ include_memberships = True
+
+ if include_memberships:
+ for event_id in events:
+ res[event_id] = self.state_ids_for_events[event_id]
+
+ else:
+ k = ("m.room.history_visibility", "")
+ for event_id in events:
+ hve = self.state_ids_for_events[event_id][k]
+ res[event_id] = {k: hve}
+
+ return succeed(res)
+
+ def get_events(self, events):
+ return succeed({event_id: self.events[event_id] for event_id in events})
+
+ def are_users_erased(self, users):
+ return succeed({u: False for u in users})
diff --git a/tests/unittest.py b/tests/unittest.py
index eea0903f05..165aafc574 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -331,16 +331,17 @@ class HomeserverTestCase(TestCase):
time.sleep(0.01)
def wait_for_background_updates(self) -> None:
- """Block until all background database updates have completed.
+ """
+ Block until all background database updates have completed.
- Note that callers must ensure there's a store property created on the
+ Note that callers must ensure that's a store property created on the
testcase.
"""
while not self.get_success(
self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db_pool.updates.do_next_background_update(False), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
def make_homeserver(self, reactor, clock):
@@ -499,7 +500,8 @@ class HomeserverTestCase(TestCase):
async def run_bg_updates():
with LoggingContext("run_bg_updates"):
- self.get_success(stor.db_pool.updates.run_background_updates(False))
+ while not await stor.db_pool.updates.has_completed_background_updates():
+ await stor.db_pool.updates.do_next_background_update(1)
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore()
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 291644eb7d..6578f3411e 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -13,7 +13,6 @@
# limitations under the License.
-from typing import List
from unittest.mock import Mock
from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries
@@ -262,17 +261,6 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase):
self.assertEquals(cache["key4"], [4])
self.assertEquals(cache["key5"], [5, 6])
- def test_zero_size_drop_from_cache(self) -> None:
- """Test that `drop_from_cache` works correctly with 0-sized entries."""
- cache: LruCache[str, List[int]] = LruCache(5, size_callback=lambda x: 0)
- cache["key1"] = []
-
- self.assertEqual(len(cache), 0)
- cache.cache["key1"].drop_from_cache()
- self.assertIsNone(
- cache.pop("key1"), "Cache entry should have been evicted but wasn't"
- )
-
class TimeEvictionTestCase(unittest.HomeserverTestCase):
"""Test that time based eviction works correctly."""
|