diff options
Diffstat (limited to 'tests')
35 files changed, 1401 insertions, 3257 deletions
diff --git a/tests/app/test_homeserver_start.py b/tests/app/test_homeserver_start.py deleted file mode 100644 index cbcada0451..0000000000 --- a/tests/app/test_homeserver_start.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import synapse.app.homeserver -from synapse.config._base import ConfigError - -from tests.config.utils import ConfigFileTestCase - - -class HomeserverAppStartTestCase(ConfigFileTestCase): - def test_wrong_start_caught(self): - # Generate a config with a worker_app - self.generate_config() - # Add a blank line as otherwise the next addition ends up on a line with a comment - self.add_lines_to_config([" "]) - self.add_lines_to_config(["worker_app: test_worker_app"]) - - # Ensure that starting master process with worker config raises an exception - with self.assertRaises(ConfigError): - synapse.app.homeserver.setup(["-c", self.config_file]) diff --git a/tests/config/test_registration_config.py b/tests/config/test_registration_config.py deleted file mode 100644 index 17a84d20d8..0000000000 --- a/tests/config/test_registration_config.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from synapse.config import ConfigError -from synapse.config.homeserver import HomeServerConfig - -from tests.unittest import TestCase -from tests.utils import default_config - - -class RegistrationConfigTestCase(TestCase): - def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self): - """ - session_lifetime should logically be larger than, or at least as large as, - all the different token lifetimes. - Test that the user is faced with configuration errors if they make it - smaller, as that configuration doesn't make sense. - """ - config_dict = default_config("test") - - # First test all the error conditions - with self.assertRaises(ConfigError): - HomeServerConfig().parse_config_dict( - { - "session_lifetime": "30m", - "nonrefreshable_access_token_lifetime": "31m", - **config_dict, - } - ) - - with self.assertRaises(ConfigError): - HomeServerConfig().parse_config_dict( - { - "session_lifetime": "30m", - "refreshable_access_token_lifetime": "31m", - **config_dict, - } - ) - - with self.assertRaises(ConfigError): - HomeServerConfig().parse_config_dict( - { - "session_lifetime": "30m", - "refresh_token_lifetime": "31m", - **config_dict, - } - ) - - # Then test all the fine conditions - HomeServerConfig().parse_config_dict( - { - "session_lifetime": "31m", - "nonrefreshable_access_token_lifetime": "31m", - **config_dict, - } - ) - - HomeServerConfig().parse_config_dict( - { - "session_lifetime": "31m", - "refreshable_access_token_lifetime": "31m", - **config_dict, - } - ) - - HomeServerConfig().parse_config_dict( - {"session_lifetime": "31m", "refresh_token_lifetime": "31m", **config_dict} - ) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 17a9fb63a1..4d1e154578 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -22,7 +22,6 @@ import signedjson.sign from nacl.signing import SigningKey from signedjson.key import encode_verify_key_base64, get_verify_key -from twisted.internet import defer from twisted.internet.defer import Deferred, ensureDeferred from synapse.api.errors import SynapseError @@ -578,76 +577,6 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) ) - def test_get_multiple_keys_from_perspectives(self): - """Check that we can correctly request multiple keys for the same server""" - - fetcher = PerspectivesKeyFetcher(self.hs) - - SERVER_NAME = "server2" - - testkey1 = signedjson.key.generate_signing_key("ver1") - testverifykey1 = signedjson.key.get_verify_key(testkey1) - testverifykey1_id = "ed25519:ver1" - - testkey2 = signedjson.key.generate_signing_key("ver2") - testverifykey2 = signedjson.key.get_verify_key(testkey2) - testverifykey2_id = "ed25519:ver2" - - VALID_UNTIL_TS = 200 * 1000 - - response1 = self.build_perspectives_response( - SERVER_NAME, - testkey1, - VALID_UNTIL_TS, - ) - response2 = self.build_perspectives_response( - SERVER_NAME, - testkey2, - VALID_UNTIL_TS, - ) - - async def post_json(destination, path, data, **kwargs): - self.assertEqual(destination, self.mock_perspective_server.server_name) - self.assertEqual(path, "/_matrix/key/v2/query") - - # check that the request is for the expected keys - q = data["server_keys"] - - self.assertEqual( - list(q[SERVER_NAME].keys()), [testverifykey1_id, testverifykey2_id] - ) - return {"server_keys": [response1, response2]} - - self.http_client.post_json.side_effect = post_json - - # fire off two separate requests; they should get merged together into a - # single HTTP hit. - request1_d = defer.ensureDeferred( - fetcher.get_keys(SERVER_NAME, [testverifykey1_id], 0) - ) - request2_d = defer.ensureDeferred( - fetcher.get_keys(SERVER_NAME, [testverifykey2_id], 0) - ) - - keys1 = self.get_success(request1_d) - self.assertIn(testverifykey1_id, keys1) - k = keys1[testverifykey1_id] - self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) - self.assertEqual(k.verify_key, testverifykey1) - self.assertEqual(k.verify_key.alg, "ed25519") - self.assertEqual(k.verify_key.version, "ver1") - - keys2 = self.get_success(request2_d) - self.assertIn(testverifykey2_id, keys2) - k = keys2[testverifykey2_id] - self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) - self.assertEqual(k.verify_key, testverifykey2) - self.assertEqual(k.verify_key.alg, "ed25519") - self.assertEqual(k.verify_key.version, "ver2") - - # finally, ensure that only one request was sent - self.assertEqual(self.http_client.post_json.call_count, 1) - def test_get_perspectives_own_key(self): """Check that we can get the perspectives server's own keys diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py deleted file mode 100644 index a7031a55f2..0000000000 --- a/tests/federation/transport/test_client.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -from synapse.api.room_versions import RoomVersions -from synapse.federation.transport.client import SendJoinParser - -from tests.unittest import TestCase - - -class SendJoinParserTestCase(TestCase): - def test_two_writes(self) -> None: - """Test that the parser can sensibly deserialise an input given in two slices.""" - parser = SendJoinParser(RoomVersions.V1, True) - parent_event = { - "content": { - "see_room_version_spec": "The event format changes depending on the room version." - }, - "event_id": "$authparent", - "room_id": "!somewhere:example.org", - "type": "m.room.minimal_pdu", - } - state = { - "content": { - "see_room_version_spec": "The event format changes depending on the room version." - }, - "event_id": "$DoNotThinkAboutTheEvent", - "room_id": "!somewhere:example.org", - "type": "m.room.minimal_pdu", - } - response = [ - 200, - { - "auth_chain": [parent_event], - "origin": "matrix.org", - "state": [state], - }, - ] - serialised_response = json.dumps(response).encode() - - # Send data to the parser - parser.write(serialised_response[:100]) - parser.write(serialised_response[100:]) - - # Retrieve the parsed SendJoinResponse - parsed_response = parser.finish() - - # Sanity check the parsing gave us sensible data. - self.assertEqual(len(parsed_response.auth_events), 1, parsed_response) - self.assertEqual(len(parsed_response.state), 1, parsed_response) - self.assertEqual(parsed_response.event_dict, {}, parsed_response) - self.assertIsNone(parsed_response.event, parsed_response) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 03b8b8615c..72e176da75 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -71,7 +71,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_short_term_login_token_gives_user_id(self): token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", duration_in_ms=5000 + self.user1, "", 5000 ) res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) self.assertEqual(self.user1, res.user_id) @@ -94,7 +94,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_short_term_login_token_cannot_replace_user_id(self): token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", duration_in_ms=5000 + self.user1, "", 5000 ) macaroon = pymacaroons.Macaroon.deserialize(token) @@ -213,6 +213,6 @@ class AuthTestCase(unittest.HomeserverTestCase): def _get_macaroon(self): token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", duration_in_ms=5000 + self.user1, "", 5000 ) return pymacaroons.Macaroon.deserialize(token) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index 8705ff8943..b625995d12 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -66,13 +66,7 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "cas", - request, - "redirect_uri", - None, - new_user=True, - auth_provider_session_id=None, + "@test_user:test", "cas", request, "redirect_uri", None, new_user=True ) def test_map_cas_user_to_existing_user(self): @@ -95,13 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "cas", - request, - "redirect_uri", - None, - new_user=False, - auth_provider_session_id=None, + "@test_user:test", "cas", request, "redirect_uri", None, new_user=False ) # Subsequent calls should map to the same mxid. @@ -110,13 +98,7 @@ class CasHandlerTestCase(HomeserverTestCase): self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "cas", - request, - "redirect_uri", - None, - new_user=False, - auth_provider_session_id=None, + "@test_user:test", "cas", request, "redirect_uri", None, new_user=False ) def test_map_cas_user_to_invalid_localpart(self): @@ -134,13 +116,7 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@f=c3=b6=c3=b6:test", - "cas", - request, - "redirect_uri", - None, - new_user=True, - auth_provider_session_id=None, + "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True ) @override_config( @@ -184,13 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "cas", - request, - "redirect_uri", - None, - new_user=True, - auth_provider_session_id=None, + "@test_user:test", "cas", request, "redirect_uri", None, new_user=True ) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index cfe3de5266..a25c89bd5b 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -252,6 +252,13 @@ class OidcHandlerTestCase(HomeserverTestCase): with patch.object(self.provider, "load_metadata", patched_load_metadata): self.get_failure(self.provider.load_jwks(force=True), RuntimeError) + # Return empty key set if JWKS are not used + self.provider._scopes = [] # not asking the openid scope + self.http_client.get_json.reset_mock() + jwks = self.get_success(self.provider.load_jwks(force=True)) + self.http_client.get_json.assert_not_called() + self.assertEqual(jwks, {"keys": []}) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_validate_config(self): """Provider metadatas are extensively validated.""" @@ -448,13 +455,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, - "oidc", - request, - client_redirect_url, - None, - new_user=True, - auth_provider_session_id=None, + expected_user_id, "oidc", request, client_redirect_url, None, new_user=True ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) @@ -481,58 +482,17 @@ class OidcHandlerTestCase(HomeserverTestCase): self.provider._fetch_userinfo.reset_mock() # With userinfo fetching - self.provider._user_profile_method = "userinfo_endpoint" - token = { - "type": "bearer", - "access_token": "access_token", - } - self.provider._exchange_code = simple_async_mock(return_value=token) + self.provider._scopes = [] # do not ask the "openid" scope self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, - "oidc", - request, - client_redirect_url, - None, - new_user=False, - auth_provider_session_id=None, + expected_user_id, "oidc", request, client_redirect_url, None, new_user=False ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_not_called() self.provider._fetch_userinfo.assert_called_once_with(token) self.render_error.assert_not_called() - # With an ID token, userinfo fetching and sid in the ID token - self.provider._user_profile_method = "userinfo_endpoint" - token = { - "type": "bearer", - "access_token": "access_token", - "id_token": "id_token", - } - id_token = { - "sid": "abcdefgh", - } - self.provider._parse_id_token = simple_async_mock(return_value=id_token) - self.provider._exchange_code = simple_async_mock(return_value=token) - auth_handler.complete_sso_login.reset_mock() - self.provider._fetch_userinfo.reset_mock() - self.get_success(self.handler.handle_oidc_callback(request)) - - auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, - "oidc", - request, - client_redirect_url, - None, - new_user=False, - auth_provider_session_id=id_token["sid"], - ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.provider._fetch_userinfo.assert_called_once_with(token) - self.render_error.assert_not_called() - # Handle userinfo fetching error self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) self.get_success(self.handler.handle_oidc_callback(request)) @@ -816,7 +776,6 @@ class OidcHandlerTestCase(HomeserverTestCase): client_redirect_url, {"phone": "1234567"}, new_user=True, - auth_provider_session_id=None, ) @override_config({"oidc_config": DEFAULT_CONFIG}) @@ -831,13 +790,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "oidc", - ANY, - ANY, - None, - new_user=True, - auth_provider_session_id=None, + "@test_user:test", "oidc", ANY, ANY, None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -848,13 +801,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user_2:test", - "oidc", - ANY, - ANY, - None, - new_user=True, - auth_provider_session_id=None, + "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -891,26 +838,14 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), - "oidc", - ANY, - ANY, - None, - new_user=False, - auth_provider_session_id=None, + user.to_string(), "oidc", ANY, ANY, None, new_user=False ) auth_handler.complete_sso_login.reset_mock() # Subsequent calls should map to the same mxid. self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), - "oidc", - ANY, - ANY, - None, - new_user=False, - auth_provider_session_id=None, + user.to_string(), "oidc", ANY, ANY, None, new_user=False ) auth_handler.complete_sso_login.reset_mock() @@ -925,13 +860,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), - "oidc", - ANY, - ANY, - None, - new_user=False, - auth_provider_session_id=None, + user.to_string(), "oidc", ANY, ANY, None, new_user=False ) auth_handler.complete_sso_login.reset_mock() @@ -967,13 +896,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@TEST_USER_2:test", - "oidc", - ANY, - ANY, - None, - new_user=False, - auth_provider_session_id=None, + "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False ) @override_config({"oidc_config": DEFAULT_CONFIG}) @@ -1011,13 +934,7 @@ class OidcHandlerTestCase(HomeserverTestCase): # test_user is already taken, so test_user1 gets registered instead. auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", - "oidc", - ANY, - ANY, - None, - new_user=True, - auth_provider_session_id=None, + "@test_user1:test", "oidc", ANY, ANY, None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -1101,13 +1018,7 @@ class OidcHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@tester:test", - "oidc", - ANY, - ANY, - None, - new_user=True, - auth_provider_session_id=None, + "@tester:test", "oidc", ANY, ANY, None, new_user=True ) @override_config( @@ -1132,13 +1043,7 @@ class OidcHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@tester:test", - "oidc", - ANY, - ANY, - None, - new_user=True, - auth_provider_session_id=None, + "@tester:test", "oidc", ANY, ANY, None, new_user=True ) @override_config( @@ -1251,7 +1156,7 @@ async def _make_callback_with_userinfo( handler = hs.get_oidc_handler() provider = handler._providers["oidc"] - provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) + provider._exchange_code = simple_async_mock(return_value={}) provider._parse_id_token = simple_async_mock(return_value=userinfo) provider._fetch_userinfo = simple_async_mock(return_value=userinfo) diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index e5a6a6c747..7b95844b55 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -32,7 +32,7 @@ from synapse.handlers.room_summary import _child_events_comparison_key, _RoomEnt from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer -from synapse.types import JsonDict, UserID, create_requester +from synapse.types import JsonDict, UserID from tests import unittest @@ -249,7 +249,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._assert_rooms(result, expected) result = self.get_success( - self.handler.get_room_hierarchy(create_requester(self.user), self.space) + self.handler.get_room_hierarchy(self.user, self.space) ) self._assert_hierarchy(result, expected) @@ -263,9 +263,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): expected = [(self.space, [self.room]), (self.room, ())] self._assert_rooms(result, expected) - result = self.get_success( - self.handler.get_room_hierarchy(create_requester(user2), self.space) - ) + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) self._assert_hierarchy(result, expected) # If the space is made invite-only, it should no longer be viewable. @@ -276,10 +274,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): tok=self.token, ) self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) - self.get_failure( - self.handler.get_room_hierarchy(create_requester(user2), self.space), - AuthError, - ) + self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) # If the space is made world-readable it should return a result. self.helper.send_state( @@ -291,9 +286,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): result = self.get_success(self.handler.get_space_summary(user2, self.space)) self._assert_rooms(result, expected) - result = self.get_success( - self.handler.get_room_hierarchy(create_requester(user2), self.space) - ) + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) self._assert_hierarchy(result, expected) # Make it not world-readable again and confirm it results in an error. @@ -304,10 +297,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): tok=self.token, ) self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) - self.get_failure( - self.handler.get_room_hierarchy(create_requester(user2), self.space), - AuthError, - ) + self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) # Join the space and results should be returned. self.helper.invite(self.space, targ=user2, tok=self.token) @@ -315,9 +305,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): result = self.get_success(self.handler.get_space_summary(user2, self.space)) self._assert_rooms(result, expected) - result = self.get_success( - self.handler.get_room_hierarchy(create_requester(user2), self.space) - ) + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) self._assert_hierarchy(result, expected) # Attempting to view an unknown room returns the same error. @@ -326,9 +314,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): AuthError, ) self.get_failure( - self.handler.get_room_hierarchy( - create_requester(user2), "#not-a-space:" + self.hs.hostname - ), + self.handler.get_room_hierarchy(user2, "#not-a-space:" + self.hs.hostname), AuthError, ) @@ -336,10 +322,10 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): """In-flight room hierarchy requests are deduplicated.""" # Run two `get_room_hierarchy` calls up until they block. deferred1 = ensureDeferred( - self.handler.get_room_hierarchy(create_requester(self.user), self.space) + self.handler.get_room_hierarchy(self.user, self.space) ) deferred2 = ensureDeferred( - self.handler.get_room_hierarchy(create_requester(self.user), self.space) + self.handler.get_room_hierarchy(self.user, self.space) ) # Complete the two calls. @@ -354,7 +340,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # A subsequent `get_room_hierarchy` call should not reuse the result. result3 = self.get_success( - self.handler.get_room_hierarchy(create_requester(self.user), self.space) + self.handler.get_room_hierarchy(self.user, self.space) ) self._assert_hierarchy(result3, expected) self.assertIsNot(result1, result3) @@ -373,11 +359,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # Run two `get_room_hierarchy` calls for different users up until they block. deferred1 = ensureDeferred( - self.handler.get_room_hierarchy(create_requester(self.user), self.space) - ) - deferred2 = ensureDeferred( - self.handler.get_room_hierarchy(create_requester(user2), self.space) + self.handler.get_room_hierarchy(self.user, self.space) ) + deferred2 = ensureDeferred(self.handler.get_room_hierarchy(user2, self.space)) # Complete the two calls. result1 = self.get_success(deferred1) @@ -481,9 +465,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ] self._assert_rooms(result, expected) - result = self.get_success( - self.handler.get_room_hierarchy(create_requester(user2), self.space) - ) + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) self._assert_hierarchy(result, expected) def test_complex_space(self): @@ -525,7 +507,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._assert_rooms(result, expected) result = self.get_success( - self.handler.get_room_hierarchy(create_requester(self.user), self.space) + self.handler.get_room_hierarchy(self.user, self.space) ) self._assert_hierarchy(result, expected) @@ -540,9 +522,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): room_ids.append(self.room) result = self.get_success( - self.handler.get_room_hierarchy( - create_requester(self.user), self.space, limit=7 - ) + self.handler.get_room_hierarchy(self.user, self.space, limit=7) ) # The result should have the space and all of the links, plus some of the # rooms and a pagination token. @@ -554,10 +534,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # Check the next page. result = self.get_success( self.handler.get_room_hierarchy( - create_requester(self.user), - self.space, - limit=5, - from_token=result["next_batch"], + self.user, self.space, limit=5, from_token=result["next_batch"] ) ) # The result should have the space and the room in it, along with a link @@ -577,22 +554,20 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): room_ids.append(self.room) result = self.get_success( - self.handler.get_room_hierarchy( - create_requester(self.user), self.space, limit=7 - ) + self.handler.get_room_hierarchy(self.user, self.space, limit=7) ) self.assertIn("next_batch", result) # Changing the room ID, suggested-only, or max-depth causes an error. self.get_failure( self.handler.get_room_hierarchy( - create_requester(self.user), self.room, from_token=result["next_batch"] + self.user, self.room, from_token=result["next_batch"] ), SynapseError, ) self.get_failure( self.handler.get_room_hierarchy( - create_requester(self.user), + self.user, self.space, suggested_only=True, from_token=result["next_batch"], @@ -601,19 +576,14 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self.get_failure( self.handler.get_room_hierarchy( - create_requester(self.user), - self.space, - max_depth=0, - from_token=result["next_batch"], + self.user, self.space, max_depth=0, from_token=result["next_batch"] ), SynapseError, ) # An invalid token is ignored. self.get_failure( - self.handler.get_room_hierarchy( - create_requester(self.user), self.space, from_token="foo" - ), + self.handler.get_room_hierarchy(self.user, self.space, from_token="foo"), SynapseError, ) @@ -639,18 +609,14 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # Test just the space itself. result = self.get_success( - self.handler.get_room_hierarchy( - create_requester(self.user), self.space, max_depth=0 - ) + self.handler.get_room_hierarchy(self.user, self.space, max_depth=0) ) expected: List[Tuple[str, Iterable[str]]] = [(spaces[0], [rooms[0], spaces[1]])] self._assert_hierarchy(result, expected) # A single additional layer. result = self.get_success( - self.handler.get_room_hierarchy( - create_requester(self.user), self.space, max_depth=1 - ) + self.handler.get_room_hierarchy(self.user, self.space, max_depth=1) ) expected += [ (rooms[0], ()), @@ -660,9 +626,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # A few layers. result = self.get_success( - self.handler.get_room_hierarchy( - create_requester(self.user), self.space, max_depth=3 - ) + self.handler.get_room_hierarchy(self.user, self.space, max_depth=3) ) expected += [ (rooms[1], ()), @@ -693,7 +657,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._assert_rooms(result, expected) result = self.get_success( - self.handler.get_room_hierarchy(create_requester(self.user), self.space) + self.handler.get_room_hierarchy(self.user, self.space) ) self._assert_hierarchy(result, expected) @@ -775,7 +739,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): new=summarize_remote_room_hierarchy, ): result = self.get_success( - self.handler.get_room_hierarchy(create_requester(self.user), self.space) + self.handler.get_room_hierarchy(self.user, self.space) ) self._assert_hierarchy(result, expected) @@ -942,7 +906,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): new=summarize_remote_room_hierarchy, ): result = self.get_success( - self.handler.get_room_hierarchy(create_requester(self.user), self.space) + self.handler.get_room_hierarchy(self.user, self.space) ) self._assert_hierarchy(result, expected) @@ -1000,7 +964,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): new=summarize_remote_room_hierarchy, ): result = self.get_success( - self.handler.get_room_hierarchy(create_requester(self.user), self.space) + self.handler.get_room_hierarchy(self.user, self.space) ) self._assert_hierarchy(result, expected) diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 50551aa6e3..8cfc184fef 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -130,13 +130,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "saml", - request, - "redirect_uri", - None, - new_user=True, - auth_provider_session_id=None, + "@test_user:test", "saml", request, "redirect_uri", None, new_user=True ) @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) @@ -162,13 +156,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "saml", - request, - "", - None, - new_user=False, - auth_provider_session_id=None, + "@test_user:test", "saml", request, "", None, new_user=False ) # Subsequent calls should map to the same mxid. @@ -177,13 +165,7 @@ class SamlHandlerTestCase(HomeserverTestCase): self.handler._handle_authn_response(request, saml_response, "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "saml", - request, - "", - None, - new_user=False, - auth_provider_session_id=None, + "@test_user:test", "saml", request, "", None, new_user=False ) def test_map_saml_response_to_invalid_localpart(self): @@ -231,13 +213,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # test_user is already taken, so test_user1 gets registered instead. auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", - "saml", - request, - "", - None, - new_user=True, - auth_provider_session_id=None, + "@test_user1:test", "saml", request, "", None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -333,13 +309,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "saml", - request, - "redirect_uri", - None, - new_user=True, - auth_provider_session_id=None, + "@test_user:test", "saml", request, "redirect_uri", None, new_user=True ) diff --git a/tests/push/test_email.py b/tests/push/test_email.py index f8cba7b645..90f800e564 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -128,7 +128,6 @@ class EmailPusherTests(HomeserverTestCase): ) self.auth_handler = hs.get_auth_handler() - self.store = hs.get_datastore() def test_need_validated_email(self): """Test that we can only add an email pusher if the user has validated @@ -409,7 +408,13 @@ class EmailPusherTests(HomeserverTestCase): self.hs.get_datastore().db_pool.updates._all_done = False # Now let's actually drive the updates to completion - self.wait_for_background_updates() + while not self.get_success( + self.hs.get_datastore().db_pool.updates.has_completed_background_updates() + ): + self.get_success( + self.hs.get_datastore().db_pool.updates.do_next_background_update(100), + by=0.1, + ) # Check that all pushers with unlinked addresses were deleted pushers = self.get_success( diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 596ba5a0c9..0a6e4795ee 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -17,7 +17,6 @@ from unittest.mock import patch from synapse.api.room_versions import RoomVersion from synapse.rest import admin from synapse.rest.client import login, room, sync -from synapse.storage.util.id_generators import MultiWriterIdGenerator from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import make_request @@ -194,10 +193,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): # # Worker2's event stream position will not advance until we call # __aexit__ again. - worker_store2 = worker_hs2.get_datastore() - assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator) - - actx = worker_store2._stream_id_gen.get_next() + actx = worker_hs2.get_datastore()._stream_id_gen.get_next() self.get_success(actx.__aenter__()) response = self.helper.send(room_id1, body="Hi!", tok=self.other_access_token) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 3adadcb46b..af849bd471 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os import urllib.parse -from http import HTTPStatus from unittest.mock import Mock from twisted.internet.defer import Deferred @@ -41,7 +41,7 @@ class VersionTestCase(unittest.HomeserverTestCase): def test_version_string(self): channel = self.make_request("GET", self.url, shorthand=False) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( {"server_version", "python_version"}, set(channel.json_body.keys()) ) @@ -70,11 +70,11 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): content={"localpart": "test"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) group_id = channel.json_body["group_id"] - self._check_group(group_id, expect_code=HTTPStatus.OK) + self._check_group(group_id, expect_code=200) # Invite/join another user @@ -82,13 +82,13 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={} ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) url = "/groups/%s/self/accept_invite" % (group_id,) channel = self.make_request( "PUT", url.encode("ascii"), access_token=self.other_user_token, content={} ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Check other user knows they're in the group self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) @@ -103,10 +103,10 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): content={"localpart": "test"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - # Check group returns HTTPStatus.NOT_FOUND - self._check_group(group_id, expect_code=HTTPStatus.NOT_FOUND) + # Check group returns 404 + self._check_group(group_id, expect_code=404) # Check users don't think they're in the group self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) @@ -122,13 +122,15 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): "GET", url.encode("ascii"), access_token=self.admin_user_tok ) - self.assertEqual(expect_code, channel.code, msg=channel.json_body) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) def _get_groups_user_is_in(self, access_token): """Returns the list of groups the user is in (given their access token)""" channel = self.make_request("GET", b"/joined_groups", access_token=access_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) return channel.json_body["groups"] @@ -208,10 +210,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Should be quarantined self.assertEqual( - HTTPStatus.NOT_FOUND, - channel.code, + 404, + int(channel.code), msg=( - "Expected to receive a HTTPStatus.NOT_FOUND on accessing quarantined media: %s" + "Expected to receive a 404 on accessing quarantined media: %s" % server_and_media_id ), ) @@ -230,8 +232,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Expect a forbidden error self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, + 403, + int(channel.result["code"]), msg="Expected forbidden on quarantining media as a non-admin", ) @@ -245,8 +247,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Expect a forbidden error self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, + 403, + int(channel.result["code"]), msg="Expected forbidden on quarantining media as a non-admin", ) @@ -277,7 +279,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): ) # Should be successful - self.assertEqual(HTTPStatus.OK, channel.code) + self.assertEqual(200, int(channel.code), msg=channel.result["body"]) # Quarantine the media url = "/_synapse/admin/v1/media/quarantine/%s/%s" % ( @@ -290,7 +292,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.code), msg=channel.result["body"]) # Attempt to access the media self._ensure_quarantined(admin_user_tok, server_name_and_media_id) @@ -346,9 +348,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.code), msg=channel.result["body"]) self.assertEqual( - channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items" + json.loads(channel.result["body"].decode("utf-8")), + {"num_quarantined": 2}, + "Expected 2 quarantined items", ) # Convert mxc URLs to server/media_id strings @@ -392,9 +396,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( - channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items" + json.loads(channel.result["body"].decode("utf-8")), + {"num_quarantined": 2}, + "Expected 2 quarantined items", ) # Attempt to access each piece of media @@ -426,7 +432,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),) channel = self.make_request("POST", url, access_token=admin_user_tok) self.pump(1.0) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.code), msg=channel.result["body"]) # Quarantine all media by this user url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( @@ -438,9 +444,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( - channel.json_body, {"num_quarantined": 1}, "Expected 1 quarantined item" + json.loads(channel.result["body"].decode("utf-8")), + {"num_quarantined": 1}, + "Expected 1 quarantined item", ) # Attempt to access each piece of media, the first should fail, the @@ -459,10 +467,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Shouldn't be quarantined self.assertEqual( - HTTPStatus.OK, - channel.code, + 200, + int(channel.code), msg=( - "Expected to receive a HTTPStatus.OK on accessing not-quarantined media: %s" + "Expected to receive a 200 on accessing not-quarantined media: %s" % server_and_media_id_2 ), ) @@ -491,7 +499,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase): def test_purge_history(self): """ Simple test of purge history API. - Test only that is is possible to call, get status HTTPStatus.OK and purge_id. + Test only that is is possible to call, get status 200 and purge_id. """ channel = self.make_request( @@ -501,7 +509,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("purge_id", channel.json_body) purge_id = channel.json_body["purge_id"] @@ -512,5 +520,5 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("complete", channel.json_body["status"]) diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index 4d152c0d66..cd5c60b65c 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -16,14 +16,11 @@ from typing import Collection from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor - import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login from synapse.server import HomeServer from synapse.storage.background_updates import BackgroundUpdater -from synapse.util import Clock from tests import unittest @@ -34,7 +31,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs: HomeServer): self.store = hs.get_datastore() self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -47,9 +44,9 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): ("POST", "/_synapse/admin/v1/background_updates/start_job"), ] ) - def test_requester_is_no_admin(self, method: str, url: str) -> None: + def test_requester_is_no_admin(self, method: str, url: str): """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ self.register_user("user", "pass", admin=False) @@ -65,7 +62,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_invalid_parameter(self) -> None: + def test_invalid_parameter(self): """ If parameters are invalid, an error is returned. """ @@ -93,7 +90,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) - def _register_bg_update(self) -> None: + def _register_bg_update(self): "Adds a bg update but doesn't start it" async def _fake_update(progress, batch_size) -> int: @@ -115,7 +112,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): ) ) - def test_status_empty(self) -> None: + def test_status_empty(self): """Test the status API works.""" channel = self.make_request( @@ -130,7 +127,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): channel.json_body, {"current_updates": {}, "enabled": True} ) - def test_status_bg_update(self) -> None: + def test_status_bg_update(self): """Test the status API works with a background update.""" # Create a new background update @@ -138,7 +135,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): self._register_bg_update() self.store.db_pool.updates.start_doing_background_updates() - self.reactor.pump([1.0, 1.0, 1.0]) + self.reactor.pump([1.0, 1.0]) channel = self.make_request( "GET", @@ -165,7 +162,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): }, ) - def test_enabled(self) -> None: + def test_enabled(self): """Test the enabled API works.""" # Create a new background update @@ -302,7 +299,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): ), ] ) - def test_start_backround_job(self, job_name: str, updates: Collection[str]) -> None: + def test_start_backround_job(self, job_name: str, updates: Collection[str]): """ Test that background updates add to database and be processed. @@ -344,7 +341,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): ) ) - def test_start_backround_job_twice(self) -> None: + def test_start_backround_job_twice(self): """Test that add a background update twice return an error.""" # add job to database diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index f7080bda87..a3679be205 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -11,18 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import urllib.parse -from http import HTTPStatus from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor - import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login -from synapse.server import HomeServer -from synapse.util import Clock from tests import unittest @@ -34,7 +30,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.handler = hs.get_device_handler() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -51,21 +47,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ) @parameterized.expand(["GET", "PUT", "DELETE"]) - def test_no_auth(self, method: str) -> None: + def test_no_auth(self, method: str): """ Try to get a device of an user without authentication. """ channel = self.make_request(method, self.url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "PUT", "DELETE"]) - def test_requester_is_no_admin(self, method: str) -> None: + def test_requester_is_no_admin(self, method: str): """ If the user is not a server admin, an error is returned. """ @@ -75,17 +67,13 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "PUT", "DELETE"]) - def test_user_does_not_exist(self, method: str) -> None: + def test_user_does_not_exist(self, method: str): """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ url = ( "/_synapse/admin/v2/users/@unknown_person:test/devices/%s" @@ -98,13 +86,13 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand(["GET", "PUT", "DELETE"]) - def test_user_is_not_local(self, method: str) -> None: + def test_user_is_not_local(self, method: str): """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = ( "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices/%s" @@ -117,12 +105,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - def test_unknown_device(self) -> None: + def test_unknown_device(self): """ - Tests that a lookup for a device that does not exist returns either HTTPStatus.NOT_FOUND or HTTPStatus.OK. + Tests that a lookup for a device that does not exist returns either 404 or 200. """ url = "/_synapse/admin/v2/users/%s/devices/unknown_device" % urllib.parse.quote( self.other_user @@ -134,7 +122,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) channel = self.make_request( @@ -143,7 +131,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) channel = self.make_request( "DELETE", @@ -151,10 +139,10 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - # Delete unknown device returns status HTTPStatus.OK - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + # Delete unknown device returns status 200 + self.assertEqual(200, channel.code, msg=channel.json_body) - def test_update_device_too_long_display_name(self) -> None: + def test_update_device_too_long_display_name(self): """ Update a device with a display name that is invalid (too long). """ @@ -179,7 +167,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): content=update, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"]) # Ensure the display name was not updated. @@ -189,12 +177,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) - def test_update_no_display_name(self) -> None: + def test_update_no_display_name(self): """ - Tests that a update for a device without JSON returns a HTTPStatus.OK + Tests that a update for a device without JSON returns a 200 """ # Set iniital display name. update = {"display_name": "new display"} @@ -210,7 +198,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Ensure the display name was not updated. channel = self.make_request( @@ -219,10 +207,10 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) - def test_update_display_name(self) -> None: + def test_update_display_name(self): """ Tests a normal successful update of display name """ @@ -234,7 +222,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): content={"display_name": "new displayname"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check new display_name channel = self.make_request( @@ -243,10 +231,10 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new displayname", channel.json_body["display_name"]) - def test_get_device(self) -> None: + def test_get_device(self): """ Tests that a normal lookup for a device is successfully """ @@ -256,7 +244,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) # Check that all fields are available self.assertIn("user_id", channel.json_body) @@ -265,7 +253,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertIn("last_seen_ip", channel.json_body) self.assertIn("last_seen_ts", channel.json_body) - def test_delete_device(self) -> None: + def test_delete_device(self): """ Tests that a remove of a device is successfully """ @@ -281,7 +269,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Ensure that the number of devices is decreased res = self.get_success(self.handler.get_devices_by_user(self.other_user)) @@ -295,7 +283,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -305,20 +293,16 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): self.other_user ) - def test_no_auth(self) -> None: + def test_no_auth(self): """ Try to list devices of an user without authentication. """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self) -> None: + def test_requester_is_no_admin(self): """ If the user is not a server admin, an error is returned. """ @@ -330,16 +314,12 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_does_not_exist(self) -> None: + def test_user_does_not_exist(self): """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v2/users/@unknown_person:test/devices" channel = self.make_request( @@ -348,12 +328,12 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_user_is_not_local(self) -> None: + def test_user_is_not_local(self): """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices" @@ -363,10 +343,10 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - def test_user_has_no_devices(self) -> None: + def test_user_has_no_devices(self): """ Tests that a normal lookup for devices is successfully if user has no devices @@ -379,11 +359,11 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["devices"])) - def test_get_devices(self) -> None: + def test_get_devices(self): """ Tests that a normal lookup for devices is successfully """ @@ -399,7 +379,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_devices, channel.json_body["total"]) self.assertEqual(number_devices, len(channel.json_body["devices"])) self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"]) @@ -419,7 +399,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.handler = hs.get_device_handler() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -431,20 +411,16 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): self.other_user ) - def test_no_auth(self) -> None: + def test_no_auth(self): """ Try to delete devices of an user without authentication. """ channel = self.make_request("POST", self.url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self) -> None: + def test_requester_is_no_admin(self): """ If the user is not a server admin, an error is returned. """ @@ -456,16 +432,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_does_not_exist(self) -> None: + def test_user_does_not_exist(self): """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices" channel = self.make_request( @@ -474,12 +446,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_user_is_not_local(self) -> None: + def test_user_is_not_local(self): """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices" @@ -489,12 +461,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - def test_unknown_devices(self) -> None: + def test_unknown_devices(self): """ - Tests that a remove of a device that does not exist returns HTTPStatus.OK. + Tests that a remove of a device that does not exist returns 200. """ channel = self.make_request( "POST", @@ -503,10 +475,10 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): content={"devices": ["unknown_device1", "unknown_device2"]}, ) - # Delete unknown devices returns status HTTPStatus.OK - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + # Delete unknown devices returns status 200 + self.assertEqual(200, channel.code, msg=channel.json_body) - def test_delete_devices(self) -> None: + def test_delete_devices(self): """ Tests that a remove of devices is successfully """ @@ -533,7 +505,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): content={"devices": device_ids}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) res = self.get_success(self.handler.get_devices_by_user(self.other_user)) self.assertEqual(0, len(res)) diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index 4f89f8b534..e9ef89731f 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -11,17 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus -from typing import List -from twisted.test.proto_helpers import MemoryReactor +import json import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login, report_event, room -from synapse.server import HomeServer -from synapse.types import JsonDict -from synapse.util import Clock from tests import unittest @@ -34,7 +29,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): report_event.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -75,22 +70,18 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/event_reports" - def test_no_auth(self) -> None: + def test_no_auth(self): """ Try to get an event report without authentication. """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self) -> None: + def test_requester_is_no_admin(self): """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( @@ -99,14 +90,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_default_success(self) -> None: + def test_default_success(self): """ Testing list of reported events """ @@ -117,13 +104,13 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["event_reports"]) - def test_limit(self) -> None: + def test_limit(self): """ Testing list of reported events with limit """ @@ -134,13 +121,13 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 5) self.assertEqual(channel.json_body["next_token"], 5) self._check_fields(channel.json_body["event_reports"]) - def test_from(self) -> None: + def test_from(self): """ Testing list of reported events with a defined starting point (from) """ @@ -151,13 +138,13 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 15) self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["event_reports"]) - def test_limit_and_from(self) -> None: + def test_limit_and_from(self): """ Testing list of reported events with a defined starting point and limit """ @@ -168,13 +155,13 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["event_reports"]), 10) self._check_fields(channel.json_body["event_reports"]) - def test_filter_room(self) -> None: + def test_filter_room(self): """ Testing list of reported events with a filter of room """ @@ -185,7 +172,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["event_reports"]), 10) self.assertNotIn("next_token", channel.json_body) @@ -194,7 +181,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): for report in channel.json_body["event_reports"]: self.assertEqual(report["room_id"], self.room_id1) - def test_filter_user(self) -> None: + def test_filter_user(self): """ Testing list of reported events with a filter of user """ @@ -205,7 +192,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["event_reports"]), 10) self.assertNotIn("next_token", channel.json_body) @@ -214,7 +201,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): for report in channel.json_body["event_reports"]: self.assertEqual(report["user_id"], self.other_user) - def test_filter_user_and_room(self) -> None: + def test_filter_user_and_room(self): """ Testing list of reported events with a filter of user and room """ @@ -225,7 +212,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 5) self.assertEqual(len(channel.json_body["event_reports"]), 5) self.assertNotIn("next_token", channel.json_body) @@ -235,7 +222,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertEqual(report["user_id"], self.other_user) self.assertEqual(report["room_id"], self.room_id1) - def test_valid_search_order(self) -> None: + def test_valid_search_order(self): """ Testing search order. Order by timestamps. """ @@ -247,7 +234,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) report = 1 @@ -265,7 +252,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) report = 1 @@ -276,9 +263,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ) report += 1 - def test_invalid_search_order(self) -> None: + def test_invalid_search_order(self): """ - Testing that a invalid search order returns a HTTPStatus.BAD_REQUEST + Testing that a invalid search order returns a 400 """ channel = self.make_request( @@ -287,17 +274,13 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual("Unknown direction: bar", channel.json_body["error"]) - def test_limit_is_negative(self) -> None: + def test_limit_is_negative(self): """ - Testing that a negative limit parameter returns a HTTPStatus.BAD_REQUEST + Testing that a negative limit parameter returns a 400 """ channel = self.make_request( @@ -306,16 +289,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - def test_from_is_negative(self) -> None: + def test_from_is_negative(self): """ - Testing that a negative from parameter returns a HTTPStatus.BAD_REQUEST + Testing that a negative from parameter returns a 400 """ channel = self.make_request( @@ -324,14 +303,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - def test_next_token(self) -> None: + def test_next_token(self): """ Testing that `next_token` appears at the right place """ @@ -344,7 +319,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertNotIn("next_token", channel.json_body) @@ -357,7 +332,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertNotIn("next_token", channel.json_body) @@ -370,7 +345,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -384,12 +359,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 1) self.assertNotIn("next_token", channel.json_body) - def _create_event_and_report(self, room_id: str, user_tok: str) -> None: + def _create_event_and_report(self, room_id, user_tok): """Create and report events""" resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] @@ -397,14 +372,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", "rooms/%s/report/%s" % (room_id, event_id), - {"score": -100, "reason": "this makes me sad"}, + json.dumps({"score": -100, "reason": "this makes me sad"}), access_token=user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - def _create_event_and_report_without_parameters( - self, room_id: str, user_tok: str - ) -> None: + def _create_event_and_report_without_parameters(self, room_id, user_tok): """Create and report an event, but omit reason and score""" resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] @@ -412,12 +385,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", "rooms/%s/report/%s" % (room_id, event_id), - {}, + json.dumps({}), access_token=user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - def _check_fields(self, content: List[JsonDict]) -> None: + def _check_fields(self, content): """Checks that all attributes are present in an event report""" for c in content: self.assertIn("id", c) @@ -440,7 +413,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): report_event.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -460,22 +433,18 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): # first created event report gets `id`=2 self.url = "/_synapse/admin/v1/event_reports/2" - def test_no_auth(self) -> None: + def test_no_auth(self): """ Try to get event report without authentication. """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self) -> None: + def test_requester_is_no_admin(self): """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( @@ -484,14 +453,10 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_default_success(self) -> None: + def test_default_success(self): """ Testing get a reported event """ @@ -502,12 +467,12 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self._check_fields(channel.json_body) - def test_invalid_report_id(self) -> None: + def test_invalid_report_id(self): """ - Testing that an invalid `report_id` returns a HTTPStatus.BAD_REQUEST. + Testing that an invalid `report_id` returns a 400. """ # `report_id` is negative @@ -517,11 +482,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "The report_id parameter must be a string representing a positive integer.", @@ -535,11 +496,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "The report_id parameter must be a string representing a positive integer.", @@ -553,20 +510,16 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "The report_id parameter must be a string representing a positive integer.", channel.json_body["error"], ) - def test_report_id_not_found(self) -> None: + def test_report_id_not_found(self): """ - Testing that a not existing `report_id` returns a HTTPStatus.NOT_FOUND. + Testing that a not existing `report_id` returns a 404. """ channel = self.make_request( @@ -575,15 +528,11 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.NOT_FOUND, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) self.assertEqual("Event report not found", channel.json_body["error"]) - def _create_event_and_report(self, room_id: str, user_tok: str) -> None: + def _create_event_and_report(self, room_id, user_tok): """Create and report events""" resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] @@ -591,12 +540,12 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", "rooms/%s/report/%s" % (room_id, event_id), - {"score": -100, "reason": "this makes me sad"}, + json.dumps({"score": -100, "reason": "this makes me sad"}), access_token=user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - def _check_fields(self, content: JsonDict) -> None: + def _check_fields(self, content): """Checks that all attributes are present in a event report""" self.assertIn("id", content) self.assertIn("received_ts", content) diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py deleted file mode 100644 index 5188499ef2..0000000000 --- a/tests/rest/admin/test_federation.py +++ /dev/null @@ -1,456 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from http import HTTPStatus -from typing import List, Optional - -from parameterized import parameterized - -import synapse.rest.admin -from synapse.api.errors import Codes -from synapse.rest.client import login -from synapse.server import HomeServer -from synapse.types import JsonDict - -from tests import unittest - - -class FederationTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - ] - - def prepare(self, reactor, clock, hs: HomeServer): - self.store = hs.get_datastore() - self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - self.url = "/_synapse/admin/v1/federation/destinations" - - @parameterized.expand( - [ - ("/_synapse/admin/v1/federation/destinations",), - ("/_synapse/admin/v1/federation/destinations/dummy",), - ] - ) - def test_requester_is_no_admin(self, url: str): - """ - If the user is not a server admin, an error 403 is returned. - """ - - self.register_user("user", "pass", admin=False) - other_user_tok = self.login("user", "pass") - - channel = self.make_request( - "GET", - url, - content={}, - access_token=other_user_tok, - ) - - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - def test_invalid_parameter(self): - """ - If parameters are invalid, an error is returned. - """ - - # negative limit - channel = self.make_request( - "GET", - self.url + "?limit=-5", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - - # negative from - channel = self.make_request( - "GET", - self.url + "?from=-5", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - - # unkown order_by - channel = self.make_request( - "GET", - self.url + "?order_by=bar", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) - - # invalid search order - channel = self.make_request( - "GET", - self.url + "?dir=bar", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) - - # invalid destination - channel = self.make_request( - "GET", - self.url + "/dummy", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - - def test_limit(self): - """ - Testing list of destinations with limit - """ - - number_destinations = 20 - self._create_destinations(number_destinations) - - channel = self.make_request( - "GET", - self.url + "?limit=5", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(channel.json_body["total"], number_destinations) - self.assertEqual(len(channel.json_body["destinations"]), 5) - self.assertEqual(channel.json_body["next_token"], "5") - self._check_fields(channel.json_body["destinations"]) - - def test_from(self): - """ - Testing list of destinations with a defined starting point (from) - """ - - number_destinations = 20 - self._create_destinations(number_destinations) - - channel = self.make_request( - "GET", - self.url + "?from=5", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(channel.json_body["total"], number_destinations) - self.assertEqual(len(channel.json_body["destinations"]), 15) - self.assertNotIn("next_token", channel.json_body) - self._check_fields(channel.json_body["destinations"]) - - def test_limit_and_from(self): - """ - Testing list of destinations with a defined starting point and limit - """ - - number_destinations = 20 - self._create_destinations(number_destinations) - - channel = self.make_request( - "GET", - self.url + "?from=5&limit=10", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(channel.json_body["total"], number_destinations) - self.assertEqual(channel.json_body["next_token"], "15") - self.assertEqual(len(channel.json_body["destinations"]), 10) - self._check_fields(channel.json_body["destinations"]) - - def test_next_token(self): - """ - Testing that `next_token` appears at the right place - """ - - number_destinations = 20 - self._create_destinations(number_destinations) - - # `next_token` does not appear - # Number of results is the number of entries - channel = self.make_request( - "GET", - self.url + "?limit=20", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(channel.json_body["total"], number_destinations) - self.assertEqual(len(channel.json_body["destinations"]), number_destinations) - self.assertNotIn("next_token", channel.json_body) - - # `next_token` does not appear - # Number of max results is larger than the number of entries - channel = self.make_request( - "GET", - self.url + "?limit=21", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(channel.json_body["total"], number_destinations) - self.assertEqual(len(channel.json_body["destinations"]), number_destinations) - self.assertNotIn("next_token", channel.json_body) - - # `next_token` does appear - # Number of max results is smaller than the number of entries - channel = self.make_request( - "GET", - self.url + "?limit=19", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(channel.json_body["total"], number_destinations) - self.assertEqual(len(channel.json_body["destinations"]), 19) - self.assertEqual(channel.json_body["next_token"], "19") - - # Check - # Set `from` to value of `next_token` for request remaining entries - # `next_token` does not appear - channel = self.make_request( - "GET", - self.url + "?from=19", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(channel.json_body["total"], number_destinations) - self.assertEqual(len(channel.json_body["destinations"]), 1) - self.assertNotIn("next_token", channel.json_body) - - def test_list_all_destinations(self): - """ - List all destinations. - """ - number_destinations = 5 - self._create_destinations(number_destinations) - - channel = self.make_request( - "GET", - self.url, - {}, - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(number_destinations, len(channel.json_body["destinations"])) - self.assertEqual(number_destinations, channel.json_body["total"]) - - # Check that all fields are available - self._check_fields(channel.json_body["destinations"]) - - def test_order_by(self): - """ - Testing order list with parameter `order_by` - """ - - def _order_test( - expected_destination_list: List[str], - order_by: Optional[str], - dir: Optional[str] = None, - ): - """Request the list of destinations in a certain order. - Assert that order is what we expect - - Args: - expected_destination_list: The list of user_id in the order - we expect to get back from the server - order_by: The type of ordering to give the server - dir: The direction of ordering to give the server - """ - - url = f"{self.url}?" - if order_by is not None: - url += f"order_by={order_by}&" - if dir is not None and dir in ("b", "f"): - url += f"dir={dir}" - channel = self.make_request( - "GET", - url, - access_token=self.admin_user_tok, - ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(channel.json_body["total"], len(expected_destination_list)) - - returned_order = [ - row["destination"] for row in channel.json_body["destinations"] - ] - self.assertEqual(expected_destination_list, returned_order) - self._check_fields(channel.json_body["destinations"]) - - # create destinations - dest = [ - ("sub-a.example.com", 100, 300, 200, 300), - ("sub-b.example.com", 200, 200, 100, 100), - ("sub-c.example.com", 300, 100, 300, 200), - ] - for ( - destination, - failure_ts, - retry_last_ts, - retry_interval, - last_successful_stream_ordering, - ) in dest: - self.get_success( - self.store.set_destination_retry_timings( - destination, failure_ts, retry_last_ts, retry_interval - ) - ) - self.get_success( - self.store.set_destination_last_successful_stream_ordering( - destination, last_successful_stream_ordering - ) - ) - - # order by default (destination) - _order_test([dest[0][0], dest[1][0], dest[2][0]], None) - _order_test([dest[0][0], dest[1][0], dest[2][0]], None, "f") - _order_test([dest[2][0], dest[1][0], dest[0][0]], None, "b") - - # order by destination - _order_test([dest[0][0], dest[1][0], dest[2][0]], "destination") - _order_test([dest[0][0], dest[1][0], dest[2][0]], "destination", "f") - _order_test([dest[2][0], dest[1][0], dest[0][0]], "destination", "b") - - # order by failure_ts - _order_test([dest[0][0], dest[1][0], dest[2][0]], "failure_ts") - _order_test([dest[0][0], dest[1][0], dest[2][0]], "failure_ts", "f") - _order_test([dest[2][0], dest[1][0], dest[0][0]], "failure_ts", "b") - - # order by retry_last_ts - _order_test([dest[2][0], dest[1][0], dest[0][0]], "retry_last_ts") - _order_test([dest[2][0], dest[1][0], dest[0][0]], "retry_last_ts", "f") - _order_test([dest[0][0], dest[1][0], dest[2][0]], "retry_last_ts", "b") - - # order by retry_interval - _order_test([dest[1][0], dest[0][0], dest[2][0]], "retry_interval") - _order_test([dest[1][0], dest[0][0], dest[2][0]], "retry_interval", "f") - _order_test([dest[2][0], dest[0][0], dest[1][0]], "retry_interval", "b") - - # order by last_successful_stream_ordering - _order_test( - [dest[1][0], dest[2][0], dest[0][0]], "last_successful_stream_ordering" - ) - _order_test( - [dest[1][0], dest[2][0], dest[0][0]], "last_successful_stream_ordering", "f" - ) - _order_test( - [dest[0][0], dest[2][0], dest[1][0]], "last_successful_stream_ordering", "b" - ) - - def test_search_term(self): - """Test that searching for a destination works correctly""" - - def _search_test( - expected_destination: Optional[str], - search_term: str, - ): - """Search for a destination and check that the returned destinationis a match - - Args: - expected_destination: The room_id expected to be returned by the API. - Set to None to expect zero results for the search - search_term: The term to search for room names with - """ - url = f"{self.url}?destination={search_term}" - channel = self.make_request( - "GET", - url.encode("ascii"), - access_token=self.admin_user_tok, - ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - - # Check that destinations were returned - self.assertTrue("destinations" in channel.json_body) - self._check_fields(channel.json_body["destinations"]) - destinations = channel.json_body["destinations"] - - # Check that the expected number of destinations were returned - expected_destination_count = 1 if expected_destination else 0 - self.assertEqual(len(destinations), expected_destination_count) - self.assertEqual(channel.json_body["total"], expected_destination_count) - - if expected_destination: - # Check that the first returned destination is correct - self.assertEqual(expected_destination, destinations[0]["destination"]) - - number_destinations = 3 - self._create_destinations(number_destinations) - - # Test searching - _search_test("sub0.example.com", "0") - _search_test("sub0.example.com", "sub0") - - _search_test("sub1.example.com", "1") - _search_test("sub1.example.com", "1.") - - # Test case insensitive - _search_test("sub0.example.com", "SUB0") - - _search_test(None, "foo") - _search_test(None, "bar") - - def test_get_single_destination(self): - """ - Get one specific destinations. - """ - self._create_destinations(5) - - channel = self.make_request( - "GET", - self.url + "/sub0.example.com", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual("sub0.example.com", channel.json_body["destination"]) - - # Check that all fields are available - # convert channel.json_body into a List - self._check_fields([channel.json_body]) - - def _create_destinations(self, number_destinations: int): - """Create a number of destinations - - Args: - number_destinations: Number of destinations to be created - """ - for i in range(0, number_destinations): - dest = f"sub{i}.example.com" - self.get_success(self.store.set_destination_retry_timings(dest, 50, 50, 50)) - self.get_success( - self.store.set_destination_last_successful_stream_ordering(dest, 100) - ) - - def _check_fields(self, content: List[JsonDict]): - """Checks that the expected destination attributes are present in content - - Args: - content: List that is checked for content - """ - for c in content: - self.assertIn("destination", c) - self.assertIn("retry_last_ts", c) - self.assertIn("retry_interval", c) - self.assertIn("failure_ts", c) - self.assertIn("last_successful_stream_ordering", c) diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 81e578fd26..db0e78c039 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -12,19 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import json import os -from http import HTTPStatus from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor - import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login, profile, room from synapse.rest.media.v1.filepath import MediaFilePaths -from synapse.server import HomeServer -from synapse.util import Clock from tests import unittest from tests.server import FakeSite, make_request @@ -42,7 +39,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname @@ -51,7 +48,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): self.filepaths = MediaFilePaths(hs.config.media.media_store_path) - def test_no_auth(self) -> None: + def test_no_auth(self): """ Try to delete media without authentication. """ @@ -59,14 +56,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): channel = self.make_request("DELETE", url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self) -> None: + def test_requester_is_no_admin(self): """ If the user is not a server admin, an error is returned. """ @@ -81,16 +74,12 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_media_does_not_exist(self) -> None: + def test_media_does_not_exist(self): """ - Tests that a lookup for a media that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a media that does not exist returns a 404 """ url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") @@ -100,12 +89,12 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_media_is_not_local(self) -> None: + def test_media_is_not_local(self): """ - Tests that a lookup for a media that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a media that is not a local returns a 400 """ url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345") @@ -115,10 +104,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) - def test_delete_media(self) -> None: + def test_delete_media(self): """ Tests that delete a media is successfully """ @@ -128,10 +117,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): # Upload some media into the room response = self.helper.upload_media( - upload_resource, - SMALL_PNG, - tok=self.admin_user_tok, - expect_code=HTTPStatus.OK, + upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200 ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -151,11 +137,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): # Should be successful self.assertEqual( - HTTPStatus.OK, + 200, channel.code, msg=( - "Expected to receive a HTTPStatus.OK on accessing media: %s" - % server_and_media_id + "Expected to receive a 200 on accessing media: %s" % server_and_media_id ), ) @@ -172,7 +157,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( media_id, @@ -189,10 +174,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) self.assertEqual( - HTTPStatus.NOT_FOUND, + 404, channel.code, msg=( - "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s" + "Expected to receive a 404 on accessing deleted media: %s" % server_and_media_id ), ) @@ -211,7 +196,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname @@ -224,21 +209,17 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): # Move clock up to somewhat realistic time self.reactor.advance(1000000000) - def test_no_auth(self) -> None: + def test_no_auth(self): """ Try to delete media without authentication. """ channel = self.make_request("POST", self.url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self) -> None: + def test_requester_is_no_admin(self): """ If the user is not a server admin, an error is returned. """ @@ -251,16 +232,12 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_media_is_not_local(self) -> None: + def test_media_is_not_local(self): """ - Tests that a lookup for media that is not local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for media that is not local returns a 400 """ url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain" @@ -270,10 +247,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) - def test_missing_parameter(self) -> None: + def test_missing_parameter(self): """ If the parameter `before_ts` is missing, an error is returned. """ @@ -283,17 +260,13 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) self.assertEqual( "Missing integer query parameter 'before_ts'", channel.json_body["error"] ) - def test_invalid_parameter(self) -> None: + def test_invalid_parameter(self): """ If parameters are invalid, an error is returned. """ @@ -303,11 +276,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts must be a positive integer.", @@ -320,11 +289,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts you provided is from the year 1970. " @@ -338,11 +303,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter size_gt must be a string representing a positive integer.", @@ -355,18 +316,14 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual( "Boolean query parameter 'keep_profiles' must be one of ['true', 'false']", channel.json_body["error"], ) - def test_delete_media_never_accessed(self) -> None: + def test_delete_media_never_accessed(self): """ Tests that media deleted if it is older than `before_ts` and never accessed `last_access_ts` is `NULL` and `created_ts` < `before_ts` @@ -388,7 +345,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( media_id, @@ -397,7 +354,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id, False) - def test_keep_media_by_date(self) -> None: + def test_keep_media_by_date(self): """ Tests that media is not deleted if it is newer than `before_ts` """ @@ -413,7 +370,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -425,7 +382,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -434,7 +391,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id, False) - def test_keep_media_by_size(self) -> None: + def test_keep_media_by_size(self): """ Tests that media is not deleted if its size is smaller than or equal to `size_gt` @@ -449,7 +406,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&size_gt=67", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -460,7 +417,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&size_gt=66", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -469,7 +426,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id, False) - def test_keep_media_by_user_avatar(self) -> None: + def test_keep_media_by_user_avatar(self): """ Tests that we do not delete media if is used as a user avatar Tests parameter `keep_profiles` @@ -482,10 +439,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", "/profile/%s/avatar_url" % (self.admin_user,), - content={"avatar_url": "mxc://%s" % (server_and_media_id,)}, + content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)}), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) now_ms = self.clock.time_msec() channel = self.make_request( @@ -493,7 +450,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -504,7 +461,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -513,7 +470,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id, False) - def test_keep_media_by_room_avatar(self) -> None: + def test_keep_media_by_room_avatar(self): """ Tests that we do not delete media if it is used as a room avatar Tests parameter `keep_profiles` @@ -527,10 +484,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", "/rooms/%s/state/m.room.avatar" % (room_id,), - content={"url": "mxc://%s" % (server_and_media_id,)}, + content=json.dumps({"url": "mxc://%s" % (server_and_media_id,)}), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) now_ms = self.clock.time_msec() channel = self.make_request( @@ -538,7 +495,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -549,7 +506,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -558,7 +515,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id, False) - def _create_media(self) -> str: + def _create_media(self): """ Create a media and return media_id and server_and_media_id """ @@ -566,10 +523,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): # Upload some media into the room response = self.helper.upload_media( - upload_resource, - SMALL_PNG, - tok=self.admin_user_tok, - expect_code=HTTPStatus.OK, + upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200 ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -580,7 +534,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): return server_and_media_id - def _access_media(self, server_and_media_id, expect_success=True) -> None: + def _access_media(self, server_and_media_id, expect_success=True): """ Try to access a media and check the result """ @@ -600,10 +554,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): if expect_success: self.assertEqual( - HTTPStatus.OK, + 200, channel.code, msg=( - "Expected to receive a HTTPStatus.OK on accessing media: %s" + "Expected to receive a 200 on accessing media: %s" % server_and_media_id ), ) @@ -611,10 +565,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertTrue(os.path.exists(local_path)) else: self.assertEqual( - HTTPStatus.NOT_FOUND, + 404, channel.code, msg=( - "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s" + "Expected to receive a 404 on accessing deleted media: %s" % (server_and_media_id) ), ) @@ -630,7 +584,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): media_repo = hs.get_media_repository_resource() self.store = hs.get_datastore() self.server_name = hs.hostname @@ -643,10 +597,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): # Upload some media into the room response = self.helper.upload_media( - upload_resource, - SMALL_PNG, - tok=self.admin_user_tok, - expect_code=HTTPStatus.OK, + upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200 ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -655,7 +606,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/media/%s/%s/%s" @parameterized.expand(["quarantine", "unquarantine"]) - def test_no_auth(self, action: str) -> None: + def test_no_auth(self, action: str): """ Try to protect media without authentication. """ @@ -666,15 +617,11 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): b"{}", ) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["quarantine", "unquarantine"]) - def test_requester_is_no_admin(self, action: str) -> None: + def test_requester_is_no_admin(self, action: str): """ If the user is not a server admin, an error is returned. """ @@ -687,14 +634,10 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_quarantine_media(self) -> None: + def test_quarantine_media(self): """ Tests that quarantining and remove from quarantine a media is successfully """ @@ -709,7 +652,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) @@ -722,13 +665,13 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) self.assertFalse(media_info["quarantined_by"]) - def test_quarantine_protected_media(self) -> None: + def test_quarantine_protected_media(self): """ Tests that quarantining from protected media fails """ @@ -747,7 +690,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) # verify that is not in quarantine @@ -763,7 +706,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): media_repo = hs.get_media_repository_resource() self.store = hs.get_datastore() @@ -775,10 +718,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): # Upload some media into the room response = self.helper.upload_media( - upload_resource, - SMALL_PNG, - tok=self.admin_user_tok, - expect_code=HTTPStatus.OK, + upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200 ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -787,22 +727,18 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/media/%s/%s" @parameterized.expand(["protect", "unprotect"]) - def test_no_auth(self, action: str) -> None: + def test_no_auth(self, action: str): """ Try to protect media without authentication. """ channel = self.make_request("POST", self.url % (action, self.media_id), b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["protect", "unprotect"]) - def test_requester_is_no_admin(self, action: str) -> None: + def test_requester_is_no_admin(self, action: str): """ If the user is not a server admin, an error is returned. """ @@ -815,14 +751,10 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_protect_media(self) -> None: + def test_protect_media(self): """ Tests that protect and unprotect a media is successfully """ @@ -837,7 +769,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) @@ -850,7 +782,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) @@ -867,7 +799,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname @@ -877,21 +809,17 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): self.filepaths = MediaFilePaths(hs.config.media.media_store_path) self.url = "/_synapse/admin/v1/purge_media_cache" - def test_no_auth(self) -> None: + def test_no_auth(self): """ Try to delete media without authentication. """ channel = self.make_request("POST", self.url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_not_admin(self) -> None: + def test_requester_is_not_admin(self): """ If the user is not a server admin, an error is returned. """ @@ -904,14 +832,10 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_invalid_parameter(self) -> None: + def test_invalid_parameter(self): """ If parameters are invalid, an error is returned. """ @@ -921,11 +845,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts must be a positive integer.", @@ -938,11 +858,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts you provided is from the year 1970. " diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py index 350a62dda6..9bac423ae0 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py @@ -11,17 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import random import string -from http import HTTPStatus - -from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login -from synapse.server import HomeServer -from synapse.util import Clock from tests import unittest @@ -32,7 +28,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -42,7 +38,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/registration_tokens" - def _new_token(self, **kwargs) -> str: + def _new_token(self, **kwargs): """Helper function to create a token.""" token = kwargs.get( "token", @@ -64,17 +60,13 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): # CREATION - def test_create_no_auth(self) -> None: + def test_create_no_auth(self): """Try to create a token without authentication.""" channel = self.make_request("POST", self.url + "/new", {}) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_create_requester_not_admin(self) -> None: + def test_create_requester_not_admin(self): """Try to create a token while not an admin.""" channel = self.make_request( "POST", @@ -82,14 +74,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_create_using_defaults(self) -> None: + def test_create_using_defaults(self): """Create a token using all the defaults.""" channel = self.make_request( "POST", @@ -98,14 +86,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(len(channel.json_body["token"]), 16) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) self.assertEqual(channel.json_body["pending"], 0) self.assertEqual(channel.json_body["completed"], 0) - def test_create_specifying_fields(self) -> None: + def test_create_specifying_fields(self): """Create a token specifying the value of all fields.""" # As many of the allowed characters as possible with length <= 64 token = "adefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._~-" @@ -122,14 +110,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["token"], token) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"]) self.assertEqual(channel.json_body["pending"], 0) self.assertEqual(channel.json_body["completed"], 0) - def test_create_with_null_value(self) -> None: + def test_create_with_null_value(self): """Create a token specifying unlimited uses and no expiry.""" data = { "uses_allowed": None, @@ -143,14 +131,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(len(channel.json_body["token"]), 16) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) self.assertEqual(channel.json_body["pending"], 0) self.assertEqual(channel.json_body["completed"], 0) - def test_create_token_too_long(self) -> None: + def test_create_token_too_long(self): """Check token longer than 64 chars is invalid.""" data = {"token": "a" * 65} @@ -161,14 +149,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_token_invalid_chars(self) -> None: + def test_create_token_invalid_chars(self): """Check you can't create token with invalid characters.""" data = { "token": "abc/def", @@ -181,14 +165,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_token_already_exists(self) -> None: + def test_create_token_already_exists(self): """Check you can't create token that already exists.""" data = { "token": "abcd", @@ -200,7 +180,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): data, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel1.code, msg=channel1.json_body) + self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"]) channel2 = self.make_request( "POST", @@ -208,10 +188,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): data, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel2.code, msg=channel2.json_body) + self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"]) self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_unable_to_generate_token(self) -> None: + def test_create_unable_to_generate_token(self): """Check right error is raised when server can't generate unique token.""" # Create all possible single character tokens tokens = [] @@ -240,9 +220,9 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 1}, access_token=self.admin_user_tok, ) - self.assertEqual(500, channel.code, msg=channel.json_body) + self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"]) - def test_create_uses_allowed(self) -> None: + def test_create_uses_allowed(self): """Check you can only create a token with good values for uses_allowed.""" # Should work with 0 (token is invalid from the start) channel = self.make_request( @@ -251,7 +231,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 0}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["uses_allowed"], 0) # Should fail with negative integer @@ -261,11 +241,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": -5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with float @@ -275,14 +251,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 1.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_expiry_time(self) -> None: + def test_create_expiry_time(self): """Check you can't create a token with an invalid expiry_time.""" # Should fail with a time in the past channel = self.make_request( @@ -291,11 +263,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": self.clock.time_msec() - 10000}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with float @@ -305,14 +273,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": self.clock.time_msec() + 1000000.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_length(self) -> None: + def test_create_length(self): """Check you can only generate a token with a valid length.""" # Should work with 64 channel = self.make_request( @@ -321,7 +285,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 64}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(len(channel.json_body["token"]), 64) # Should fail with 0 @@ -331,11 +295,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 0}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a negative integer @@ -345,11 +305,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": -5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a float @@ -359,11 +315,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 8.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with 65 @@ -373,30 +325,22 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 65}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # UPDATING - def test_update_no_auth(self) -> None: + def test_update_no_auth(self): """Try to update a token without authentication.""" channel = self.make_request( "PUT", self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_update_requester_not_admin(self) -> None: + def test_update_requester_not_admin(self): """Try to update a token while not an admin.""" channel = self.make_request( "PUT", @@ -404,14 +348,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_update_non_existent(self) -> None: + def test_update_non_existent(self): """Try to update a token that doesn't exist.""" channel = self.make_request( "PUT", @@ -420,14 +360,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.NOT_FOUND, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_update_uses_allowed(self) -> None: + def test_update_uses_allowed(self): """Test updating just uses_allowed.""" # Create new token using default values token = self._new_token() @@ -439,7 +375,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 1}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertIsNone(channel.json_body["expiry_time"]) @@ -450,7 +386,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 0}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["uses_allowed"], 0) self.assertIsNone(channel.json_body["expiry_time"]) @@ -461,7 +397,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": None}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -472,11 +408,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 1.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a negative integer @@ -486,14 +418,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": -5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_update_expiry_time(self) -> None: + def test_update_expiry_time(self): """Test updating just expiry_time.""" # Create new token using default values token = self._new_token() @@ -506,7 +434,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": new_expiry_time}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) self.assertIsNone(channel.json_body["uses_allowed"]) @@ -517,7 +445,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": None}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertIsNone(channel.json_body["expiry_time"]) self.assertIsNone(channel.json_body["uses_allowed"]) @@ -529,11 +457,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": past_time}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail a float @@ -543,14 +467,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": new_expiry_time + 0.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_update_both(self) -> None: + def test_update_both(self): """Test updating both uses_allowed and expiry_time.""" # Create new token using default values token = self._new_token() @@ -568,11 +488,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) - def test_update_invalid_type(self) -> None: + def test_update_invalid_type(self): """Test using invalid types doesn't work.""" # Create new token using default values token = self._new_token() @@ -589,30 +509,22 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # DELETING - def test_delete_no_auth(self) -> None: + def test_delete_no_auth(self): """Try to delete a token without authentication.""" channel = self.make_request( "DELETE", self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_delete_requester_not_admin(self) -> None: + def test_delete_requester_not_admin(self): """Try to delete a token while not an admin.""" channel = self.make_request( "DELETE", @@ -620,14 +532,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_delete_non_existent(self) -> None: + def test_delete_non_existent(self): """Try to delete a token that doesn't exist.""" channel = self.make_request( "DELETE", @@ -636,14 +544,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.NOT_FOUND, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_delete(self) -> None: + def test_delete(self): """Test deleting a token.""" # Create new token using default values token = self._new_token() @@ -655,25 +559,21 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # GETTING ONE - def test_get_no_auth(self) -> None: + def test_get_no_auth(self): """Try to get a token without authentication.""" channel = self.make_request( "GET", self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_get_requester_not_admin(self) -> None: + def test_get_requester_not_admin(self): """Try to get a token while not an admin.""" channel = self.make_request( "GET", @@ -681,14 +581,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_get_non_existent(self) -> None: + def test_get_non_existent(self): """Try to get a token that doesn't exist.""" channel = self.make_request( "GET", @@ -697,14 +593,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.NOT_FOUND, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_get(self) -> None: + def test_get(self): """Test getting a token.""" # Create new token using default values token = self._new_token() @@ -716,7 +608,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["token"], token) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -725,17 +617,13 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): # LISTING - def test_list_no_auth(self) -> None: + def test_list_no_auth(self): """Try to list tokens without authentication.""" channel = self.make_request("GET", self.url, {}) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_list_requester_not_admin(self) -> None: + def test_list_requester_not_admin(self): """Try to list tokens while not an admin.""" channel = self.make_request( "GET", @@ -743,14 +631,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_list_all(self) -> None: + def test_list_all(self): """Test listing all tokens.""" # Create new token using default values token = self._new_token() @@ -762,7 +646,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(len(channel.json_body["registration_tokens"]), 1) token_info = channel.json_body["registration_tokens"][0] self.assertEqual(token_info["token"], token) @@ -771,7 +655,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.assertEqual(token_info["pending"], 0) self.assertEqual(token_info["completed"], 0) - def test_list_invalid_query_parameter(self) -> None: + def test_list_invalid_query_parameter(self): """Test with `valid` query parameter not `true` or `false`.""" channel = self.make_request( "GET", @@ -780,13 +664,9 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - def _test_list_query_parameter(self, valid: str) -> None: + def _test_list_query_parameter(self, valid: str): """Helper used to test both valid=true and valid=false.""" # Create 2 valid and 2 invalid tokens. now = self.hs.get_clock().time_msec() @@ -816,17 +696,17 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(len(channel.json_body["registration_tokens"]), 2) token_info_1 = channel.json_body["registration_tokens"][0] token_info_2 = channel.json_body["registration_tokens"][1] self.assertIn(token_info_1["token"], tokens) self.assertIn(token_info_2["token"], tokens) - def test_list_valid(self) -> None: + def test_list_valid(self): """Test listing just valid tokens.""" self._test_list_query_parameter(valid="true") - def test_list_invalid(self) -> None: + def test_list_invalid(self): """Test listing just invalid tokens.""" self._test_list_query_parameter(valid="false") diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 22f9aa6234..07077aff78 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import json import urllib.parse from http import HTTPStatus from typing import List, Optional @@ -18,15 +20,11 @@ from unittest.mock import Mock from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor - import synapse.rest.admin from synapse.api.constants import EventTypes, Membership from synapse.api.errors import Codes from synapse.handlers.pagination import PaginationHandler from synapse.rest.client import directory, events, login, room -from synapse.server import HomeServer -from synapse.util import Clock from tests import unittest @@ -42,7 +40,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): room.register_deprecated_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.event_creation_handler = hs.get_event_creation_handler() hs.config.consent.user_consent_version = "1" @@ -68,7 +66,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): def test_requester_is_no_admin(self): """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( @@ -78,12 +76,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_room_does_not_exist(self): """ - Check that unknown rooms/server return 200 + Check that unknown rooms/server return error 404. """ url = "/_synapse/admin/v1/rooms/%s" % "!unknown:test" @@ -94,11 +92,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_room_is_not_valid(self): """ - Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. + Check that invalid room names, return an error 400. """ url = "/_synapse/admin/v1/rooms/%s" % "invalidroom" @@ -109,7 +108,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom is not a legal room ID", channel.json_body["error"], @@ -119,15 +118,16 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ Tests that the user ID must be from local server but it does not have to exist. """ + body = json.dumps({"new_room_user_id": "@unknown:test"}) channel = self.make_request( "DELETE", self.url, - content={"new_room_user_id": "@unknown:test"}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("new_room_id", channel.json_body) self.assertIn("kicked_users", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -137,15 +137,16 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ Check that only local users can create new room to move members. """ + body = json.dumps({"new_room_user_id": "@not:exist.bla"}) channel = self.make_request( "DELETE", self.url, - content={"new_room_user_id": "@not:exist.bla"}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "User must be our own: @not:exist.bla", channel.json_body["error"], @@ -155,30 +156,32 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ If parameter `block` is not boolean, return an error """ + body = json.dumps({"block": "NotBool"}) channel = self.make_request( "DELETE", self.url, - content={"block": "NotBool"}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_is_not_bool(self): """ If parameter `purge` is not boolean, return an error """ + body = json.dumps({"purge": "NotBool"}) channel = self.make_request( "DELETE", self.url, - content={"purge": "NotBool"}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_room_and_block(self): @@ -195,14 +198,16 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Assert one user in room self._is_member(room_id=self.room_id, user_id=self.other_user) + body = json.dumps({"block": True, "purge": True}) + channel = self.make_request( "DELETE", self.url.encode("ascii"), - content={"block": True, "purge": True}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -226,14 +231,16 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Assert one user in room self._is_member(room_id=self.room_id, user_id=self.other_user) + body = json.dumps({"block": False, "purge": True}) + channel = self.make_request( "DELETE", self.url.encode("ascii"), - content={"block": False, "purge": True}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -258,14 +265,16 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Assert one user in room self._is_member(room_id=self.room_id, user_id=self.other_user) + body = json.dumps({"block": True, "purge": False}) + channel = self.make_request( "DELETE", self.url.encode("ascii"), - content={"block": True, "purge": False}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -296,7 +305,9 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): ) # The room is now blocked. - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual( + HTTPStatus.OK, int(channel.result["code"]), msg=channel.result["body"] + ) self._is_blocked(room_id) def test_shutdown_room_consent(self): @@ -316,10 +327,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Assert that the user is getting consent error self.helper.send( - self.room_id, - body="foo", - tok=self.other_user_tok, - expect_code=HTTPStatus.FORBIDDEN, + self.room_id, body="foo", tok=self.other_user_tok, expect_code=403 ) # Test that room is not purged @@ -333,11 +341,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - {"new_room_user_id": self.admin_user}, + json.dumps({"new_room_user_id": self.admin_user}), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("new_room_id", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -363,10 +371,10 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", url.encode("ascii"), - {"history_visibility": "world_readable"}, + json.dumps({"history_visibility": "world_readable"}), access_token=self.other_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Test that room is not purged with self.assertRaises(AssertionError): @@ -379,11 +387,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - {"new_room_user_id": self.admin_user}, + json.dumps({"new_room_user_id": self.admin_user}), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("new_room_id", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -398,7 +406,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self._has_no_members(self.room_id) # Assert we can no longer peek into the room - self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN) + self._assert_peek(self.room_id, expect_code=403) def _is_blocked(self, room_id, expect=True): """Assert that the room is blocked or not""" @@ -457,7 +465,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): room.register_deprecated_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.event_creation_handler = hs.get_event_creation_handler() hs.config.consent.user_consent_version = "1" @@ -494,7 +502,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): ) def test_requester_is_no_admin(self, method: str, url: str): """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( @@ -507,36 +515,27 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_room_does_not_exist(self): + @parameterized.expand( + [ + ("DELETE", "/_synapse/admin/v2/rooms/%s"), + ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"), + ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"), + ] + ) + def test_room_does_not_exist(self, method: str, url: str): """ - Check that unknown rooms/server return 200 - - This is important, as it allows incomplete vestiges of rooms to be cleared up - even if the create event/etc is missing. + Check that unknown rooms/server return error 404. """ - room_id = "!unknown:test" - channel = self.make_request( - "DELETE", - f"/_synapse/admin/v2/rooms/{room_id}", - content={}, - access_token=self.admin_user_tok, - ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertIn("delete_id", channel.json_body) - delete_id = channel.json_body["delete_id"] - - # get status channel = self.make_request( - "GET", - f"/_synapse/admin/v2/rooms/{room_id}/delete_status", + method, + url % "!unknown:test", + content={}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(1, len(channel.json_body["results"])) - self.assertEqual("complete", channel.json_body["results"][0]["status"]) - self.assertEqual(delete_id, channel.json_body["results"][0]["delete_id"]) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand( [ @@ -546,7 +545,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): ) def test_room_is_not_valid(self, method: str, url: str): """ - Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. + Check that invalid room names, return an error 400. """ channel = self.make_request( @@ -855,10 +854,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): # Assert that the user is getting consent error self.helper.send( - self.room_id, - body="foo", - tok=self.other_user_tok, - expect_code=HTTPStatus.FORBIDDEN, + self.room_id, body="foo", tok=self.other_user_tok, expect_code=403 ) # Test that room is not purged @@ -955,7 +951,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self._has_no_members(self.room_id) # Assert we can no longer peek into the room - self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN) + self._assert_peek(self.room_id, expect_code=403) def _is_blocked(self, room_id: str, expect: bool = True) -> None: """Assert that the room is blocked or not""" @@ -1073,12 +1069,12 @@ class RoomTestCase(unittest.HomeserverTestCase): directory.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): # Create user self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - def test_list_rooms(self) -> None: + def test_list_rooms(self): """Test that we can list rooms""" # Create 3 test rooms total_rooms = 3 @@ -1098,7 +1094,7 @@ class RoomTestCase(unittest.HomeserverTestCase): ) # Check request completed successfully - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that response json body contains a "rooms" key self.assertTrue( @@ -1142,7 +1138,7 @@ class RoomTestCase(unittest.HomeserverTestCase): # We shouldn't receive a next token here as there's no further rooms to show self.assertNotIn("next_batch", channel.json_body) - def test_list_rooms_pagination(self) -> None: + def test_list_rooms_pagination(self): """Test that we can get a full list of rooms through pagination""" # Create 5 test rooms total_rooms = 5 @@ -1182,7 +1178,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue("rooms" in channel.json_body) for r in channel.json_body["rooms"]: @@ -1222,9 +1218,9 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) - def test_correct_room_attributes(self) -> None: + def test_correct_room_attributes(self): """Test the correct attributes for a room are returned""" # Create a test room room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1245,7 +1241,7 @@ class RoomTestCase(unittest.HomeserverTestCase): {"room_id": room_id}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Set this new alias as the canonical alias for this room self.helper.send_state( @@ -1277,7 +1273,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that rooms were returned self.assertTrue("rooms" in channel.json_body) @@ -1305,7 +1301,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(test_room_name, r["name"]) self.assertEqual(test_alias, r["canonical_alias"]) - def test_room_list_sort_order(self) -> None: + def test_room_list_sort_order(self): """Test room list sort ordering. alphabetical name versus number of members, reversing the order, etc. """ @@ -1314,7 +1310,7 @@ class RoomTestCase(unittest.HomeserverTestCase): order_type: str, expected_room_list: List[str], reverse: bool = False, - ) -> None: + ): """Request the list of rooms in a certain order. Assert that order is what we expect @@ -1332,7 +1328,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that rooms were returned self.assertTrue("rooms" in channel.json_body) @@ -1443,7 +1439,7 @@ class RoomTestCase(unittest.HomeserverTestCase): _order_test("state_events", [room_id_3, room_id_2, room_id_1]) _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True) - def test_search_term(self) -> None: + def test_search_term(self): """Test that searching for a room works correctly""" # Create two test rooms room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1471,8 +1467,8 @@ class RoomTestCase(unittest.HomeserverTestCase): def _search_test( expected_room_id: Optional[str], search_term: str, - expected_http_code: int = HTTPStatus.OK, - ) -> None: + expected_http_code: int = 200, + ): """Search for a room and check that the returned room's id is a match Args: @@ -1489,7 +1485,7 @@ class RoomTestCase(unittest.HomeserverTestCase): ) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) - if expected_http_code != HTTPStatus.OK: + if expected_http_code != 200: return # Check that rooms were returned @@ -1532,7 +1528,7 @@ class RoomTestCase(unittest.HomeserverTestCase): _search_test(None, "foo") _search_test(None, "bar") - _search_test(None, "", expected_http_code=HTTPStatus.BAD_REQUEST) + _search_test(None, "", expected_http_code=400) # Test that the whole room id returns the room _search_test(room_id_1, room_id_1) @@ -1546,7 +1542,7 @@ class RoomTestCase(unittest.HomeserverTestCase): # Test search local part of alias _search_test(room_id_1, "alias1") - def test_search_term_non_ascii(self) -> None: + def test_search_term_non_ascii(self): """Test that searching for a room with non-ASCII characters works correctly""" # Create test room @@ -1569,11 +1565,11 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(room_id, channel.json_body.get("rooms")[0].get("room_id")) self.assertEqual("ж", channel.json_body.get("rooms")[0].get("name")) - def test_single_room(self) -> None: + def test_single_room(self): """Test that a single room can be requested correctly""" # Create two test rooms room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1602,7 +1598,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("room_id", channel.json_body) self.assertIn("name", channel.json_body) @@ -1624,7 +1620,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(room_id_1, channel.json_body["room_id"]) - def test_single_room_devices(self) -> None: + def test_single_room_devices(self): """Test that `joined_local_devices` can be requested correctly""" room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1634,7 +1630,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["joined_local_devices"]) # Have another user join the room @@ -1648,7 +1644,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(2, channel.json_body["joined_local_devices"]) # leave room @@ -1660,10 +1656,10 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["joined_local_devices"]) - def test_room_members(self) -> None: + def test_room_members(self): """Test that room members can be requested correctly""" # Create two test rooms room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1691,7 +1687,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertCountEqual( ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"] @@ -1704,14 +1700,14 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertCountEqual( ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"] ) self.assertEqual(channel.json_body["total"], 3) - def test_room_state(self) -> None: + def test_room_state(self): """Test that room state can be requested correctly""" # Create two test rooms room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1722,15 +1718,13 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("state", channel.json_body) # testing that the state events match is painful and not done here. We assume that # the create_room already does the right thing, so no need to verify that we got # the state events it created. - def _set_canonical_alias( - self, room_id: str, test_alias: str, admin_user_tok: str - ) -> None: + def _set_canonical_alias(self, room_id: str, test_alias: str, admin_user_tok: str): # Create a new alias to this room url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) channel = self.make_request( @@ -1739,7 +1733,7 @@ class RoomTestCase(unittest.HomeserverTestCase): {"room_id": room_id}, access_token=admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Set this new alias as the canonical alias for this room self.helper.send_state( @@ -1765,7 +1759,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, homeserver): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -1780,117 +1774,124 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): ) self.url = f"/_synapse/admin/v1/join/{self.public_room_id}" - def test_requester_is_no_admin(self) -> None: + def test_requester_is_no_admin(self): """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ + body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( "POST", self.url, - content={"user_id": self.second_user_id}, + content=body, access_token=self.second_tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_invalid_parameter(self) -> None: + def test_invalid_parameter(self): """ If a parameter is missing, return an error """ + body = json.dumps({"unknown_parameter": "@unknown:test"}) channel = self.make_request( "POST", self.url, - content={"unknown_parameter": "@unknown:test"}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) - def test_local_user_does_not_exist(self) -> None: + def test_local_user_does_not_exist(self): """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ + body = json.dumps({"user_id": "@unknown:test"}) channel = self.make_request( "POST", self.url, - content={"user_id": "@unknown:test"}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_remote_user(self) -> None: + def test_remote_user(self): """ Check that only local user can join rooms. """ + body = json.dumps({"user_id": "@not:exist.bla"}) channel = self.make_request( "POST", self.url, - content={"user_id": "@not:exist.bla"}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "This endpoint can only be used with local users", channel.json_body["error"], ) - def test_room_does_not_exist(self) -> None: + def test_room_does_not_exist(self): """ - Check that unknown rooms/server return error HTTPStatus.NOT_FOUND. + Check that unknown rooms/server return error 404. """ + body = json.dumps({"user_id": self.second_user_id}) url = "/_synapse/admin/v1/join/!unknown:test" channel = self.make_request( "POST", url, - content={"user_id": self.second_user_id}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("No known servers", channel.json_body["error"]) - def test_room_is_not_valid(self) -> None: + def test_room_is_not_valid(self): """ - Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. + Check that invalid room names, return an error 400. """ + body = json.dumps({"user_id": self.second_user_id}) url = "/_synapse/admin/v1/join/invalidroom" channel = self.make_request( "POST", url, - content={"user_id": self.second_user_id}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom was not legal room ID or room alias", channel.json_body["error"], ) - def test_join_public_room(self) -> None: + def test_join_public_room(self): """ Test joining a local user to a public room with "JoinRules.PUBLIC" """ + body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( "POST", self.url, - content={"user_id": self.second_user_id}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1900,10 +1901,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) - def test_join_private_room_if_not_member(self) -> None: + def test_join_private_room_if_not_member(self): """ Test joining a local user to a private room with "JoinRules.INVITE" when server admin is not member of this room. @@ -1912,18 +1913,19 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.creator, tok=self.creator_tok, is_public=False ) url = f"/_synapse/admin/v1/join/{private_room_id}" + body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( "POST", url, - content={"user_id": self.second_user_id}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_join_private_room_if_member(self) -> None: + def test_join_private_room_if_member(self): """ Test joining a local user to a private room with "JoinRules.INVITE", when server admin is member of this room. @@ -1948,20 +1950,21 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) # Join user to room. url = f"/_synapse/admin/v1/join/{private_room_id}" + body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( "POST", url, - content={"user_id": self.second_user_id}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1971,10 +1974,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) - def test_join_private_room_if_owner(self) -> None: + def test_join_private_room_if_owner(self): """ Test joining a local user to a private room with "JoinRules.INVITE", when server admin is owner of this room. @@ -1983,15 +1986,16 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.admin_user, tok=self.admin_user_tok, is_public=False ) url = f"/_synapse/admin/v1/join/{private_room_id}" + body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( "POST", url, - content={"user_id": self.second_user_id}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -2001,10 +2005,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) - def test_context_as_non_admin(self) -> None: + def test_context_as_non_admin(self): """ Test that, without being admin, one cannot use the context admin API """ @@ -2035,10 +2039,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=tok, ) - self.assertEquals(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEquals(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_context_as_admin(self) -> None: + def test_context_as_admin(self): """ Test that, as admin, we can find the context of an event without having joined the room. """ @@ -2065,7 +2069,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=self.admin_user_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEquals( channel.json_body["event"]["event_id"], events[midway]["event_id"] ) @@ -2094,7 +2098,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, homeserver): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -2111,7 +2115,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): self.public_room_id ) - def test_public_room(self) -> None: + def test_public_room(self): """Test that getting admin in a public room works.""" room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True @@ -2124,7 +2128,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Now we test that we can join the room and ban a user. self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok) @@ -2136,7 +2140,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): tok=self.admin_user_tok, ) - def test_private_room(self) -> None: + def test_private_room(self): """Test that getting admin in a private room works and we get invited.""" room_id = self.helper.create_room_as( self.creator, @@ -2151,7 +2155,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Now we test that we can join the room (we should have received an # invite) and can ban a user. @@ -2164,7 +2168,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): tok=self.admin_user_tok, ) - def test_other_user(self) -> None: + def test_other_user(self): """Test that giving admin in a public room works to a non-admin user works.""" room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True @@ -2177,7 +2181,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Now we test that we can join the room and ban a user. self.helper.join(room_id, self.second_user_id, tok=self.second_tok) @@ -2189,7 +2193,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): tok=self.second_tok, ) - def test_not_enough_power(self) -> None: + def test_not_enough_power(self): """Test that we get a sensible error if there are no local room admins.""" room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True @@ -2211,11 +2215,11 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - # We expect this to fail with a HTTPStatus.BAD_REQUEST as there are no room admins. + # We expect this to fail with a 400 as there are no room admins. # # (Note we assert the error message to ensure that it's not denied for # some other reason) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( channel.json_body["error"], "No local admin user in room with power to update power levels.", @@ -2229,7 +2233,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self._store = hs.get_datastore() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -2244,8 +2248,8 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/rooms/%s/block" @parameterized.expand([("PUT",), ("GET",)]) - def test_requester_is_no_admin(self, method: str) -> None: - """If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.""" + def test_requester_is_no_admin(self, method: str): + """If the user is not a server admin, an error 403 is returned.""" channel = self.make_request( method, @@ -2258,8 +2262,8 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand([("PUT",), ("GET",)]) - def test_room_is_not_valid(self, method: str) -> None: - """Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.""" + def test_room_is_not_valid(self, method: str): + """Check that invalid room names, return an error 400.""" channel = self.make_request( method, @@ -2274,7 +2278,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): channel.json_body["error"], ) - def test_block_is_not_valid(self) -> None: + def test_block_is_not_valid(self): """If parameter `block` is not valid, return an error.""" # `block` is not valid @@ -2309,7 +2313,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"]) - def test_block_room(self) -> None: + def test_block_room(self): """Test that block a room is successful.""" def _request_and_test_block_room(room_id: str) -> None: @@ -2333,7 +2337,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): # unknown remote room _request_and_test_block_room("!unknown:remote") - def test_block_room_twice(self) -> None: + def test_block_room_twice(self): """Test that block a room that is already blocked is successful.""" self._is_blocked(self.room_id, expect=False) @@ -2348,7 +2352,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): self.assertTrue(channel.json_body["block"]) self._is_blocked(self.room_id, expect=True) - def test_unblock_room(self) -> None: + def test_unblock_room(self): """Test that unblock a room is successful.""" def _request_and_test_unblock_room(room_id: str) -> None: @@ -2373,7 +2377,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): # unknown remote room _request_and_test_unblock_room("!unknown:remote") - def test_unblock_room_twice(self) -> None: + def test_unblock_room_twice(self): """Test that unblock a room that is not blocked is successful.""" self._block_room(self.room_id) @@ -2388,7 +2392,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): self.assertFalse(channel.json_body["block"]) self._is_blocked(self.room_id, expect=False) - def test_get_blocked_room(self) -> None: + def test_get_blocked_room(self): """Test get status of a blocked room""" def _request_blocked_room(room_id: str) -> None: @@ -2412,7 +2416,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): # unknown remote room _request_blocked_room("!unknown:remote") - def test_get_unblocked_room(self) -> None: + def test_get_unblocked_room(self): """Test get status of a unblocked room""" def _request_unblocked_room(room_id: str) -> None: diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index 3c59f5f766..fbceba3254 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -11,18 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus -from typing import List -from twisted.test.proto_helpers import MemoryReactor +from typing import List import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login, room, sync -from synapse.server import HomeServer from synapse.storage.roommember import RoomsForUser from synapse.types import JsonDict -from synapse.util import Clock from tests import unittest from tests.unittest import override_config @@ -37,7 +33,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_handler = hs.get_pagination_handler() @@ -52,18 +48,14 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/send_server_notice" - def test_no_auth(self) -> None: + def test_no_auth(self): """Try to send a server notice without authentication.""" channel = self.make_request("POST", self.url) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self) -> None: + def test_requester_is_no_admin(self): """If the user is not a server admin, an error is returned.""" channel = self.make_request( "POST", @@ -71,16 +63,12 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_user_does_not_exist(self) -> None: - """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND""" + def test_user_does_not_exist(self): + """Tests that a lookup for a user that does not exist returns a 404""" channel = self.make_request( "POST", self.url, @@ -88,13 +76,13 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): content={"user_id": "@unknown_person:test", "content": ""}, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_user_is_not_local(self) -> None: + def test_user_is_not_local(self): """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ channel = self.make_request( "POST", @@ -106,13 +94,13 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "Server notices can only be sent to local users", channel.json_body["error"] ) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_invalid_parameter(self) -> None: + def test_invalid_parameter(self): """If parameters are invalid, an error is returned.""" # no content, no user @@ -122,7 +110,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"]) # no content @@ -133,7 +121,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): content={"user_id": self.other_user}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) # no body @@ -144,7 +132,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): content={"user_id": self.other_user, "content": ""}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual("'body' not in content", channel.json_body["error"]) @@ -156,11 +144,11 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): content={"user_id": self.other_user, "content": {"body": ""}}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual("'msgtype' not in content", channel.json_body["error"]) - def test_server_notice_disabled(self) -> None: + def test_server_notice_disabled(self): """Tests that server returns error if server notice is disabled""" channel = self.make_request( "POST", @@ -172,14 +160,14 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual( "Server notices are not enabled on this server", channel.json_body["error"] ) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_send_server_notice(self) -> None: + def test_send_server_notice(self): """ Tests that sending two server notices is successfully, the server uses the same room and do not send messages twice. @@ -197,7 +185,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg one"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -228,7 +216,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg two"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has no new invites or memberships self._check_invite_and_join_status(self.other_user, 0, 1) @@ -243,7 +231,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): self.assertEqual(messages[1]["sender"], "@notices:test") @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_send_server_notice_leave_room(self) -> None: + def test_send_server_notice_leave_room(self): """ Tests that sending a server notices is successfully. The user leaves the room and the second message appears @@ -262,7 +250,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg one"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -305,7 +293,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg two"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -327,7 +315,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): self.assertNotEqual(first_room_id, second_room_id) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_send_server_notice_delete_room(self) -> None: + def test_send_server_notice_delete_room(self): """ Tests that the user get server notice in a new room after the first server notice room was deleted. @@ -345,7 +333,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg one"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -394,7 +382,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg two"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -417,7 +405,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): def _check_invite_and_join_status( self, user_id: str, expected_invites: int, expected_memberships: int - ) -> List[RoomsForUser]: + ) -> RoomsForUser: """Check invite and room membership status of a user. Args @@ -452,7 +440,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "/_matrix/client/r0/sync", access_token=token ) - self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.code, 200) # Get the messages room = channel.json_body["rooms"]["join"][room_id] diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 7cb8ec57ba..ece89a65ac 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -12,17 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus -from typing import List, Optional -from twisted.test.proto_helpers import MemoryReactor +import json +from typing import Any, Dict, List, Optional import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login -from synapse.server import HomeServer -from synapse.types import JsonDict -from synapse.util import Clock from tests import unittest from tests.test_utils import SMALL_PNG @@ -34,7 +30,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.media_repo = hs.get_media_repository_resource() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -45,38 +41,30 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/statistics/users/media" - def test_no_auth(self) -> None: + def test_no_auth(self): """ Try to list users without authentication. """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self) -> None: + def test_requester_is_no_admin(self): """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( "GET", self.url, - {}, + json.dumps({}), access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_invalid_parameter(self) -> None: + def test_invalid_parameter(self): """ If parameters are invalid, an error is returned. """ @@ -87,11 +75,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -101,11 +85,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative limit @@ -115,11 +95,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from_ts @@ -129,11 +105,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative until_ts @@ -143,11 +115,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # until_ts smaller from_ts @@ -157,11 +125,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # empty search term @@ -171,11 +135,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order @@ -185,14 +145,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - def test_limit(self) -> None: + def test_limit(self): """ Testing list of media with limit """ @@ -204,13 +160,13 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["users"]), 5) self.assertEqual(channel.json_body["next_token"], 5) self._check_fields(channel.json_body["users"]) - def test_from(self) -> None: + def test_from(self): """ Testing list of media with a defined starting point (from) """ @@ -222,13 +178,13 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["users"]), 15) self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["users"]) - def test_limit_and_from(self) -> None: + def test_limit_and_from(self): """ Testing list of media with a defined starting point and limit """ @@ -240,13 +196,13 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["users"]), 10) self._check_fields(channel.json_body["users"]) - def test_next_token(self) -> None: + def test_next_token(self): """ Testing that `next_token` appears at the right place """ @@ -262,7 +218,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -275,7 +231,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -288,7 +244,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -301,12 +257,12 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 1) self.assertNotIn("next_token", channel.json_body) - def test_no_media(self) -> None: + def test_no_media(self): """ Tests that a normal lookup for statistics is successfully if users have no media created @@ -318,11 +274,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["users"])) - def test_order_by(self) -> None: + def test_order_by(self): """ Testing order list with parameter `order_by` """ @@ -400,7 +356,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): "b", ) - def test_from_until_ts(self) -> None: + def test_from_until_ts(self): """ Testing filter by time with parameters `from_ts` and `until_ts` """ @@ -415,7 +371,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["users"][0]["media_count"], 3) # filter media starting at `ts1` after creating first media @@ -425,7 +381,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 0) self._create_media(self.other_user_tok, 3) @@ -440,7 +396,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["users"][0]["media_count"], 3) # filter media until `ts2` and earlier @@ -449,10 +405,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["users"][0]["media_count"], 6) - def test_search_term(self) -> None: + def test_search_term(self): self._create_users_with_media(20, 1) # check without filter get all users @@ -461,7 +417,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) # filter user 1 and 10-19 by `user_id` @@ -470,7 +426,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?search_term=foo_user_1", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 11) # filter on this user in `displayname` @@ -479,7 +435,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?search_term=bar_user_10", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["users"][0]["displayname"], "bar_user_10") self.assertEqual(channel.json_body["total"], 1) @@ -489,10 +445,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?search_term=foobar", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 0) - def _create_users_with_media(self, number_users: int, media_per_user: int) -> None: + def _create_users_with_media(self, number_users: int, media_per_user: int): """ Create a number of users with a number of media Args: @@ -504,7 +460,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): user_tok = self.login("foo_user_%s" % i, "pass") self._create_media(user_tok, media_per_user) - def _create_media(self, user_token: str, number_media: int) -> None: + def _create_media(self, user_token: str, number_media: int): """ Create a number of media for a specific user Args: @@ -515,10 +471,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): for _ in range(number_media): # Upload some media into the room self.helper.upload_media( - upload_resource, SMALL_PNG, tok=user_token, expect_code=HTTPStatus.OK + upload_resource, SMALL_PNG, tok=user_token, expect_code=200 ) - def _check_fields(self, content: List[JsonDict]) -> None: + def _check_fields(self, content: List[Dict[str, Any]]): """Checks that all attributes are present in content Args: content: List that is checked for content @@ -531,7 +487,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): def _order_test( self, order_type: str, expected_user_list: List[str], dir: Optional[str] = None - ) -> None: + ): """Request the list of users in a certain order. Assert that order is what we expect Args: @@ -549,7 +505,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_user_list)) returned_order = [row["user_id"] for row in channel.json_body["users"]] diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 4fedd5fd08..5011e54563 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -17,7 +17,6 @@ import hmac import os import urllib.parse from binascii import unhexlify -from http import HTTPStatus from typing import List, Optional from unittest.mock import Mock, patch @@ -75,7 +74,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "Shared secret registration is not enabled", channel.json_body["error"] ) @@ -107,7 +106,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): body = {"nonce": nonce} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) # 61 seconds @@ -115,7 +114,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_register_incorrect_nonce(self): @@ -127,18 +126,18 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin") - want_mac_str = want_mac.hexdigest() + want_mac = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob", "password": "abc123", "admin": True, - "mac": want_mac_str, + "mac": want_mac, } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("HMAC incorrect", channel.json_body["error"]) def test_register_correct_nonce(self): @@ -153,7 +152,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update( nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support" ) - want_mac_str = want_mac.hexdigest() + want_mac = want_mac.hexdigest() body = { "nonce": nonce, @@ -161,11 +160,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "password": "abc123", "admin": True, "user_type": UserTypes.SUPPORT, - "mac": want_mac_str, + "mac": want_mac, } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) def test_nonce_reuse(self): @@ -177,24 +176,24 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin") - want_mac_str = want_mac.hexdigest() + want_mac = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob", "password": "abc123", "admin": True, - "mac": want_mac_str, + "mac": want_mac, } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) # Now, try and reuse it channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_missing_parts(self): @@ -215,7 +214,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be an empty body present channel = self.make_request("POST", self.url, {}) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("nonce must be specified", channel.json_body["error"]) # @@ -225,28 +224,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present channel = self.make_request("POST", self.url, {"nonce": nonce()}) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) # Must be a string body = {"nonce": nonce(), "username": 1234} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = {"nonce": nonce(), "username": "abcd\u0000"} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = {"nonce": nonce(), "username": "a" * 1000} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # @@ -257,28 +256,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): body = {"nonce": nonce(), "username": "a"} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("password must be specified", channel.json_body["error"]) # Must be a string body = {"nonce": nonce(), "username": "a", "password": 1234} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Must not have null bytes body = {"nonce": nonce(), "username": "a", "password": "abcd\u0000"} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Super long body = {"nonce": nonce(), "username": "a", "password": "A" * 1000} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # @@ -294,7 +293,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid user type", channel.json_body["error"]) def test_displayname(self): @@ -308,22 +307,22 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob1\x00abc123\x00notadmin") - want_mac_str = want_mac.hexdigest() + want_mac = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob1", "password": "abc123", - "mac": want_mac_str, + "mac": want_mac, } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob1:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob1:test/displayname") - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("bob1", channel.json_body["displayname"]) # displayname is None @@ -332,22 +331,22 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob2\x00abc123\x00notadmin") - want_mac_str = want_mac.hexdigest() + want_mac = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob2", "displayname": None, "password": "abc123", - "mac": want_mac_str, + "mac": want_mac, } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob2:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob2:test/displayname") - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("bob2", channel.json_body["displayname"]) # displayname is empty @@ -356,22 +355,22 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob3\x00abc123\x00notadmin") - want_mac_str = want_mac.hexdigest() + want_mac = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob3", "displayname": "", "password": "abc123", - "mac": want_mac_str, + "mac": want_mac, } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob3:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob3:test/displayname") - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) # set displayname channel = self.make_request("GET", self.url) @@ -379,22 +378,22 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob4\x00abc123\x00notadmin") - want_mac_str = want_mac.hexdigest() + want_mac = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob4", "displayname": "Bob's Name", "password": "abc123", - "mac": want_mac_str, + "mac": want_mac, } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob4:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob4:test/displayname") - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("Bob's Name", channel.json_body["displayname"]) @override_config( @@ -426,7 +425,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update( nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support" ) - want_mac_str = want_mac.hexdigest() + want_mac = want_mac.hexdigest() body = { "nonce": nonce, @@ -434,11 +433,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "password": "abc123", "admin": True, "user_type": UserTypes.SUPPORT, - "mac": want_mac_str, + "mac": want_mac, } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -462,7 +461,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -474,7 +473,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", self.url, access_token=other_user_token) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_all_users(self): @@ -490,7 +489,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(3, len(channel.json_body["users"])) self.assertEqual(3, channel.json_body["total"]) @@ -504,7 +503,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): expected_user_id: Optional[str], search_term: str, search_field: Optional[str] = "name", - expected_http_code: Optional[int] = HTTPStatus.OK, + expected_http_code: Optional[int] = 200, ): """Search for a user and check that the returned user's id is a match @@ -526,7 +525,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): ) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) - if expected_http_code != HTTPStatus.OK: + if expected_http_code != 200: return # Check that users were returned @@ -587,7 +586,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -597,7 +596,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid guests @@ -607,7 +606,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # invalid deactivated @@ -617,7 +616,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # unkown order_by @@ -627,7 +626,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # invalid search order @@ -637,7 +636,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) def test_limit(self): @@ -655,7 +654,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 5) self.assertEqual(channel.json_body["next_token"], "5") @@ -676,7 +675,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -697,7 +696,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(channel.json_body["next_token"], "15") self.assertEqual(len(channel.json_body["users"]), 10) @@ -720,7 +719,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -733,7 +732,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -746,7 +745,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 19) self.assertEqual(channel.json_body["next_token"], "19") @@ -760,7 +759,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -863,14 +862,14 @@ class UsersListTestCase(unittest.HomeserverTestCase): url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_user_list)) returned_order = [row["name"] for row in channel.json_body["users"]] self.assertEqual(expected_user_list, returned_order) self._check_fields(channel.json_body["users"]) - def _check_fields(self, content: List[JsonDict]): + def _check_fields(self, content: JsonDict): """Checks that the expected user attributes are present in content Args: content: List that is checked for content @@ -937,7 +936,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_not_admin(self): @@ -948,7 +947,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", url, access_token=self.other_user_token) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) channel = self.make_request( @@ -958,12 +957,12 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content=b"{}", ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) def test_user_does_not_exist(self): """ - Tests that deactivation for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that deactivation for a user that does not exist returns a 404 """ channel = self.make_request( @@ -972,7 +971,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_erase_is_not_bool(self): @@ -987,18 +986,18 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_user_is_not_local(self): """ - Tests that deactivation for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that deactivation for a user that is not a local returns a 400 """ url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain" channel = self.make_request("POST", url, access_token=self.admin_user_tok) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only deactivate local users", channel.json_body["error"]) def test_deactivate_user_erase_true(self): @@ -1013,7 +1012,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1028,7 +1027,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content={"erase": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1037,7 +1036,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1058,7 +1057,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1073,7 +1072,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content={"erase": False}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1082,7 +1081,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1112,7 +1111,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1127,7 +1126,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content={"erase": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1136,7 +1135,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1196,7 +1195,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) channel = self.make_request( @@ -1206,12 +1205,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=b"{}", ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) def test_user_does_not_exist(self): """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ channel = self.make_request( @@ -1220,7 +1219,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) def test_invalid_parameter(self): @@ -1235,7 +1234,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"admin": "not_bool"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) # deactivated not bool @@ -1245,7 +1244,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": "not_bool"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # password not str @@ -1255,7 +1254,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"password": True}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # password not length @@ -1265,7 +1264,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"password": "x" * 513}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # user_type not valid @@ -1275,7 +1274,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"user_type": "new type"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # external_ids not valid @@ -1287,7 +1286,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): "external_ids": {"auth_provider": "prov", "wrong_external_id": "id"} }, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) channel = self.make_request( @@ -1296,7 +1295,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"external_ids": {"external_id": "id"}}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) # threepids not valid @@ -1306,7 +1305,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"threepids": {"medium": "email", "wrong_address": "id"}}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) channel = self.make_request( @@ -1315,7 +1314,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"threepids": {"address": "value"}}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) def test_get_user(self): @@ -1328,7 +1327,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("User", channel.json_body["displayname"]) self._check_fields(channel.json_body) @@ -1371,7 +1370,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1434,7 +1433,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1462,9 +1461,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): # before limit of monthly active users is reached channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok) - if channel.code != HTTPStatus.OK: + if channel.code != 200: raise HttpResponseException( - channel.code, channel.result["reason"], channel.json_body + channel.code, channel.result["reason"], channel.result["body"] ) # Set monthly active users to the limit @@ -1626,7 +1625,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "hahaha"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self._check_fields(channel.json_body) def test_set_displayname(self): @@ -1642,7 +1641,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"displayname": "foobar"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) @@ -1653,7 +1652,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) @@ -1675,7 +1674,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) # result does not always have the same sort order, therefore it becomes sorted @@ -1701,7 +1700,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1717,7 +1716,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1733,7 +1732,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"threepids": []}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(0, len(channel.json_body["threepids"])) self._check_fields(channel.json_body) @@ -1760,7 +1759,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1779,7 +1778,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1801,7 +1800,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) # other user has this two threepids - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) # result does not always have the same sort order, therefore it becomes sorted @@ -1820,7 +1819,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): url_first_user, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(0, len(channel.json_body["threepids"])) self._check_fields(channel.json_body) @@ -1849,7 +1848,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["external_ids"])) # result does not always have the same sort order, therefore it becomes sorted @@ -1881,7 +1880,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["external_ids"])) self.assertEqual( @@ -1900,7 +1899,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["external_ids"])) self.assertEqual( @@ -1919,7 +1918,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"external_ids": []}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(0, len(channel.json_body["external_ids"])) @@ -1948,7 +1947,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -1974,7 +1973,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -2006,7 +2005,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) # must fail - self.assertEqual(HTTPStatus.CONFLICT, channel.code, msg=channel.json_body) + self.assertEqual(409, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual("External id is already in use.", channel.json_body["error"]) @@ -2017,7 +2016,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -2035,7 +2034,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -2066,7 +2065,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -2081,7 +2080,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"deactivated": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -2097,7 +2096,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -2124,7 +2123,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"deactivated": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) @@ -2140,7 +2139,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"displayname": "Foobar"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertEqual("Foobar", channel.json_body["displayname"]) @@ -2164,7 +2163,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) # Reactivate the user. channel = self.make_request( @@ -2173,7 +2172,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertIsNotNone(channel.json_body["password_hash"]) @@ -2195,7 +2194,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Reactivate the user without a password. @@ -2205,7 +2204,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -2227,7 +2226,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Reactivate the user without a password. @@ -2237,7 +2236,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -2256,7 +2255,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"admin": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) @@ -2267,7 +2266,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) @@ -2284,7 +2283,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"user_type": UserTypes.SUPPORT}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) @@ -2295,7 +2294,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) @@ -2307,7 +2306,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"user_type": None}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertIsNone(channel.json_body["user_type"]) @@ -2318,7 +2317,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertIsNone(channel.json_body["user_type"]) @@ -2348,7 +2347,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) self.assertEqual(0, channel.json_body["deactivated"]) @@ -2361,7 +2360,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "abc123", "deactivated": "false"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) # Check user is not deactivated channel = self.make_request( @@ -2370,7 +2369,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) @@ -2395,7 +2394,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) self._is_erased(user_id, False) @@ -2446,7 +2445,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -2461,7 +2460,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self): @@ -2475,7 +2474,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["joined_rooms"])) @@ -2491,7 +2490,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["joined_rooms"])) @@ -2507,7 +2506,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["joined_rooms"])) @@ -2528,7 +2527,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_rooms, channel.json_body["total"]) self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"])) @@ -2575,7 +2574,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"]) @@ -2604,7 +2603,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -2619,12 +2618,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self): """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/pushers" channel = self.make_request( @@ -2633,12 +2632,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_user_is_not_local(self): """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers" @@ -2648,7 +2647,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only look up local users", channel.json_body["error"]) def test_get_pushers(self): @@ -2663,7 +2662,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) # Register the pusher @@ -2694,7 +2693,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) for p in channel.json_body["pushers"]: @@ -2733,7 +2732,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): """Try to list media of an user without authentication.""" channel = self.make_request(method, self.url, {}) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "DELETE"]) @@ -2747,12 +2746,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "DELETE"]) def test_user_does_not_exist(self, method: str): - """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND""" + """Tests that a lookup for a user that does not exist returns a 404""" url = "/_synapse/admin/v1/users/@unknown_person:test/media" channel = self.make_request( method, @@ -2760,12 +2759,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand(["GET", "DELETE"]) def test_user_is_not_local(self, method: str): - """Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST""" + """Tests that a lookup for a user that is not a local returns a 400""" url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media" channel = self.make_request( @@ -2774,7 +2773,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only look up local users", channel.json_body["error"]) def test_limit_GET(self): @@ -2790,7 +2789,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 5) self.assertEqual(channel.json_body["next_token"], 5) @@ -2809,7 +2808,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 5) self.assertEqual(len(channel.json_body["deleted_media"]), 5) @@ -2826,7 +2825,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -2845,7 +2844,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 15) self.assertEqual(len(channel.json_body["deleted_media"]), 15) @@ -2862,7 +2861,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["media"]), 10) @@ -2881,7 +2880,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["deleted_media"]), 10) @@ -2895,7 +2894,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # invalid search order @@ -2905,7 +2904,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # negative limit @@ -2915,7 +2914,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -2925,7 +2924,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_next_token(self): @@ -2948,7 +2947,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), number_media) self.assertNotIn("next_token", channel.json_body) @@ -2961,7 +2960,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), number_media) self.assertNotIn("next_token", channel.json_body) @@ -2974,7 +2973,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -2988,7 +2987,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -3005,7 +3004,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["media"])) @@ -3020,7 +3019,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["deleted_media"])) @@ -3037,7 +3036,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_media, channel.json_body["total"]) self.assertEqual(number_media, len(channel.json_body["media"])) self.assertNotIn("next_token", channel.json_body) @@ -3063,7 +3062,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_media, channel.json_body["total"]) self.assertEqual(number_media, len(channel.json_body["deleted_media"])) self.assertCountEqual(channel.json_body["deleted_media"], media_ids) @@ -3208,7 +3207,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): # Upload some media into the room response = self.helper.upload_media( - upload_resource, image_data, user_token, filename, expect_code=HTTPStatus.OK + upload_resource, image_data, user_token, filename, expect_code=200 ) # Extract media ID from the response @@ -3226,16 +3225,16 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.OK, + 200, channel.code, msg=( - f"Expected to receive a HTTPStatus.OK on accessing media: {server_and_media_id}" + f"Expected to receive a 200 on accessing media: {server_and_media_id}" ), ) return media_id - def _check_fields(self, content: List[JsonDict]): + def _check_fields(self, content: JsonDict): """Checks that the expected user attributes are present in content Args: content: List that is checked for content @@ -3275,7 +3274,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_media_list)) returned_order = [row["media_id"] for row in channel.json_body["media"]] @@ -3311,14 +3310,14 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, b"{}", access_token=self.admin_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) return channel.json_body["access_token"] def test_no_auth(self): """Try to login as a user without authentication.""" channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_not_admin(self): @@ -3327,7 +3326,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): "POST", self.url, b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) def test_send_event(self): """Test that sending event as a user works.""" @@ -3352,7 +3351,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # We should only see the one device (from the login in `prepare`) self.assertEqual(len(channel.json_body["devices"]), 1) @@ -3364,21 +3363,21 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout with the puppet token channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) # .. but the real user's tokens should still work channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) def test_user_logout_all(self): """Tests that the target user calling `/logout/all` does *not* expire @@ -3389,23 +3388,23 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout all with the real user token channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should still work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # .. but the real user's tokens shouldn't channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) def test_admin_logout_all(self): """Tests that the admin user calling `/logout/all` does expire the @@ -3416,23 +3415,23 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout all with the admin user token channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.admin_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) # .. but the real user's tokens should still work channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) @unittest.override_config( { @@ -3460,10 +3459,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Now unaccept it and check that we can't send an event self.get_success(self.store.user_set_consent_version(self.other_user, "0.0")) self.helper.send_event( - room_id, - "com.example.test", - tok=self.other_user_tok, - expect_code=HTTPStatus.FORBIDDEN, + room_id, "com.example.test", tok=self.other_user_tok, expect_code=403 ) # Login in as the user @@ -3481,10 +3477,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Trying to join as the other user should fail due to reaching MAU limit. self.helper.join( - room_id, - user=self.other_user, - tok=self.other_user_tok, - expect_code=HTTPStatus.FORBIDDEN, + room_id, user=self.other_user, tok=self.other_user_tok, expect_code=403 ) # Logging in as the other user and joining a room should work, even @@ -3519,7 +3512,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): Try to get information of an user without authentication. """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_not_admin(self): @@ -3534,12 +3527,12 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.url, access_token=other_user2_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_is_not_local(self): """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = self.url_prefix % "@unknown_person:unknown_domain" @@ -3548,7 +3541,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only whois a local user", channel.json_body["error"]) def test_get_whois_admin(self): @@ -3560,7 +3553,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) @@ -3575,7 +3568,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.url, access_token=other_user_token, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) @@ -3605,7 +3598,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): Try to get information of an user without authentication. """ channel = self.make_request(method, self.url) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["POST", "DELETE"]) @@ -3616,18 +3609,18 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): other_user_token = self.login("user", "pass") channel = self.make_request(method, self.url, access_token=other_user_token) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["POST", "DELETE"]) def test_user_is_not_local(self, method: str): """ - Tests that shadow-banning for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that shadow-banning for a user that is not a local returns a 400 """ url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" channel = self.make_request(method, url, access_token=self.admin_user_tok) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) def test_success(self): """ @@ -3639,7 +3632,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): self.assertFalse(result.shadow_banned) channel = self.make_request("POST", self.url, access_token=self.admin_user_tok) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual({}, channel.json_body) # Ensure the user is shadow-banned (and the cache was cleared). @@ -3650,7 +3643,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, access_token=self.admin_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual({}, channel.json_body) # Ensure the user is no longer shadow-banned (and the cache was cleared). @@ -3684,7 +3677,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): """ channel = self.make_request(method, self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "POST", "DELETE"]) @@ -3700,13 +3693,13 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "POST", "DELETE"]) def test_user_does_not_exist(self, method: str): """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit" @@ -3716,7 +3709,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand( @@ -3728,7 +3721,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): ) def test_user_is_not_local(self, method: str, error_msg: str): """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = ( "/_synapse/admin/v1/users/@unknown_person:unknown_domain/override_ratelimit" @@ -3740,7 +3733,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(error_msg, channel.json_body["error"]) def test_invalid_parameter(self): @@ -3755,7 +3748,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"messages_per_second": "string"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # messages_per_second is negative @@ -3766,7 +3759,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"messages_per_second": -1}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # burst_count is a string @@ -3777,7 +3770,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"burst_count": "string"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # burst_count is negative @@ -3788,7 +3781,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"burst_count": -1}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_return_zero_when_null(self): @@ -3813,7 +3806,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["messages_per_second"]) self.assertEqual(0, channel.json_body["burst_count"]) @@ -3827,7 +3820,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -3838,7 +3831,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"messages_per_second": 10, "burst_count": 11}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(10, channel.json_body["messages_per_second"]) self.assertEqual(11, channel.json_body["burst_count"]) @@ -3849,7 +3842,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"messages_per_second": 20, "burst_count": 21}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(21, channel.json_body["burst_count"]) @@ -3859,7 +3852,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(21, channel.json_body["burst_count"]) @@ -3869,7 +3862,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -3879,6 +3872,6 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py index 7978626e71..4e1c49c28b 100644 --- a/tests/rest/admin/test_username_available.py +++ b/tests/rest/admin/test_username_available.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus - import synapse.rest.admin from synapse.api.errors import Codes, SynapseError from synapse.rest.client import login @@ -35,38 +33,30 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase): async def check_username(username): if username == "allowed": return True - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "User ID already taken.", - errcode=Codes.USER_IN_USE, - ) + raise SynapseError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) handler = self.hs.get_registration_handler() handler.check_username = check_username def test_username_available(self): """ - The endpoint should return a HTTPStatus.OK response if the username does not exist + The endpoint should return a 200 response if the username does not exist """ url = "%s?username=%s" % (self.url, "allowed") channel = self.make_request("GET", url, None, self.admin_user_tok) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertTrue(channel.json_body["available"]) def test_username_unavailable(self): """ - The endpoint should return a HTTPStatus.OK response if the username does not exist + The endpoint should return a 200 response if the username does not exist """ url = "%s?username=%s" % (self.url, "disallowed") channel = self.make_request("GET", url, None, self.admin_user_tok) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], "M_USER_IN_USE") self.assertEqual(channel.json_body["error"], "User ID already taken.") diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 72bbc87b4a..8552671431 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus from typing import Optional, Union from twisted.internet.defer import succeed @@ -514,39 +513,12 @@ class RefreshAuthTests(unittest.HomeserverTestCase): self.user_pass = "pass" self.user = self.register_user("test", self.user_pass) - def use_refresh_token(self, refresh_token: str) -> FakeChannel: - """ - Helper that makes a request to use a refresh token. - """ - return self.make_request( - "POST", - "/_matrix/client/v1/refresh", - {"refresh_token": refresh_token}, - ) - - def is_access_token_valid(self, access_token) -> bool: - """ - Checks whether an access token is valid, returning whether it is or not. - """ - code = self.make_request( - "GET", "/_matrix/client/v3/account/whoami", access_token=access_token - ).code - - # Either 200 or 401 is what we get back; anything else is a bug. - assert code in {HTTPStatus.OK, HTTPStatus.UNAUTHORIZED} - - return code == HTTPStatus.OK - def test_login_issue_refresh_token(self): """ A login response should include a refresh_token only if asked. """ # Test login - body = { - "type": "m.login.password", - "user": "test", - "password": self.user_pass, - } + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} login_without_refresh = self.make_request( "POST", "/_matrix/client/r0/login", body @@ -556,8 +528,8 @@ class RefreshAuthTests(unittest.HomeserverTestCase): login_with_refresh = self.make_request( "POST", - "/_matrix/client/r0/login", - {"refresh_token": True, **body}, + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, ) self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result) self.assertIn("refresh_token", login_with_refresh.json_body) @@ -583,12 +555,11 @@ class RefreshAuthTests(unittest.HomeserverTestCase): register_with_refresh = self.make_request( "POST", - "/_matrix/client/r0/register", + "/_matrix/client/r0/register?org.matrix.msc2918.refresh_token=true", { "username": "test3", "password": self.user_pass, "auth": {"type": LoginType.DUMMY}, - "refresh_token": True, }, ) self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result) @@ -599,22 +570,17 @@ class RefreshAuthTests(unittest.HomeserverTestCase): """ A refresh token can be used to issue a new access token. """ - body = { - "type": "m.login.password", - "user": "test", - "password": self.user_pass, - "refresh_token": True, - } + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} login_response = self.make_request( "POST", - "/_matrix/client/r0/login", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", body, ) self.assertEqual(login_response.code, 200, login_response.result) refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual(refresh_response.code, 200, refresh_response.result) @@ -633,19 +599,14 @@ class RefreshAuthTests(unittest.HomeserverTestCase): ) @override_config({"refreshable_access_token_lifetime": "1m"}) - def test_refreshable_access_token_expiration(self): + def test_refresh_token_expiration(self): """ The access token should have some time as specified in the config. """ - body = { - "type": "m.login.password", - "user": "test", - "password": self.user_pass, - "refresh_token": True, - } + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} login_response = self.make_request( "POST", - "/_matrix/client/r0/login", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", body, ) self.assertEqual(login_response.code, 200, login_response.result) @@ -655,198 +616,13 @@ class RefreshAuthTests(unittest.HomeserverTestCase): refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual(refresh_response.code, 200, refresh_response.result) self.assertApproximates( refresh_response.json_body["expires_in_ms"], 60 * 1000, 100 ) - access_token = refresh_response.json_body["access_token"] - - # Advance 59 seconds in the future (just shy of 1 minute, the time of expiry) - self.reactor.advance(59.0) - # Check that our token is valid - self.assertEqual( - self.make_request( - "GET", "/_matrix/client/v3/account/whoami", access_token=access_token - ).code, - HTTPStatus.OK, - ) - - # Advance 2 more seconds (just past the time of expiry) - self.reactor.advance(2.0) - # Check that our token is invalid - self.assertEqual( - self.make_request( - "GET", "/_matrix/client/v3/account/whoami", access_token=access_token - ).code, - HTTPStatus.UNAUTHORIZED, - ) - - @override_config( - { - "refreshable_access_token_lifetime": "1m", - "nonrefreshable_access_token_lifetime": "10m", - } - ) - def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self): - """ - Tests that the expiry times for refreshable and non-refreshable access - tokens can be different. - """ - body = { - "type": "m.login.password", - "user": "test", - "password": self.user_pass, - } - login_response1 = self.make_request( - "POST", - "/_matrix/client/r0/login", - {"refresh_token": True, **body}, - ) - self.assertEqual(login_response1.code, 200, login_response1.result) - self.assertApproximates( - login_response1.json_body["expires_in_ms"], 60 * 1000, 100 - ) - refreshable_access_token = login_response1.json_body["access_token"] - - login_response2 = self.make_request( - "POST", - "/_matrix/client/r0/login", - body, - ) - self.assertEqual(login_response2.code, 200, login_response2.result) - nonrefreshable_access_token = login_response2.json_body["access_token"] - - # Advance 59 seconds in the future (just shy of 1 minute, the time of expiry) - self.reactor.advance(59.0) - - # Both tokens should still be valid. - self.assertTrue(self.is_access_token_valid(refreshable_access_token)) - self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) - - # Advance to 61 s (just past 1 minute, the time of expiry) - self.reactor.advance(2.0) - - # Only the non-refreshable token is still valid. - self.assertFalse(self.is_access_token_valid(refreshable_access_token)) - self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) - - # Advance to 599 s (just shy of 10 minutes, the time of expiry) - self.reactor.advance(599.0 - 61.0) - - # It's still the case that only the non-refreshable token is still valid. - self.assertFalse(self.is_access_token_valid(refreshable_access_token)) - self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) - - # Advance to 601 s (just past 10 minutes, the time of expiry) - self.reactor.advance(2.0) - - # Now neither token is valid. - self.assertFalse(self.is_access_token_valid(refreshable_access_token)) - self.assertFalse(self.is_access_token_valid(nonrefreshable_access_token)) - - @override_config( - {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"} - ) - def test_refresh_token_expiry(self): - """ - The refresh token can be configured to have a limited lifetime. - When that lifetime has ended, the refresh token can no longer be used to - refresh the session. - """ - - body = { - "type": "m.login.password", - "user": "test", - "password": self.user_pass, - "refresh_token": True, - } - login_response = self.make_request( - "POST", - "/_matrix/client/r0/login", - body, - ) - self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) - refresh_token1 = login_response.json_body["refresh_token"] - - # Advance 119 seconds in the future (just shy of 2 minutes) - self.reactor.advance(119.0) - - # Refresh our session. The refresh token should still JUST be valid right now. - # By doing so, we get a new access token and a new refresh token. - refresh_response = self.use_refresh_token(refresh_token1) - self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) - self.assertIn( - "refresh_token", - refresh_response.json_body, - "No new refresh token returned after refresh.", - ) - refresh_token2 = refresh_response.json_body["refresh_token"] - - # Advance 121 seconds in the future (just a bit more than 2 minutes) - self.reactor.advance(121.0) - - # Try to refresh our session, but instead notice that the refresh token is - # not valid (it just expired). - refresh_response = self.use_refresh_token(refresh_token2) - self.assertEqual( - refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result - ) - - @override_config( - { - "refreshable_access_token_lifetime": "2m", - "refresh_token_lifetime": "2m", - "session_lifetime": "3m", - } - ) - def test_ultimate_session_expiry(self): - """ - The session can be configured to have an ultimate, limited lifetime. - """ - - body = { - "type": "m.login.password", - "user": "test", - "password": self.user_pass, - "refresh_token": True, - } - login_response = self.make_request( - "POST", - "/_matrix/client/r0/login", - body, - ) - self.assertEqual(login_response.code, 200, login_response.result) - refresh_token = login_response.json_body["refresh_token"] - - # Advance shy of 2 minutes into the future - self.reactor.advance(119.0) - - # Refresh our session. The refresh token should still be valid right now. - refresh_response = self.use_refresh_token(refresh_token) - self.assertEqual(refresh_response.code, 200, refresh_response.result) - self.assertIn( - "refresh_token", - refresh_response.json_body, - "No new refresh token returned after refresh.", - ) - # Notice that our access token lifetime has been diminished to match the - # session lifetime. - # 3 minutes - 119 seconds = 61 seconds. - self.assertEqual(refresh_response.json_body["expires_in_ms"], 61_000) - refresh_token = refresh_response.json_body["refresh_token"] - - # Advance 61 seconds into the future. Our session should have expired - # now, because we've had our 3 minutes. - self.reactor.advance(61.0) - - # Try to issue a new, refreshed, access token. - # This should fail because the refresh token's lifetime has also been - # diminished as our session expired. - refresh_response = self.use_refresh_token(refresh_token) - self.assertEqual(refresh_response.code, 403, refresh_response.result) def test_refresh_token_invalidation(self): """Refresh tokens are invalidated after first use of the next token. @@ -864,15 +640,10 @@ class RefreshAuthTests(unittest.HomeserverTestCase): |-> fourth_refresh (fails) """ - body = { - "type": "m.login.password", - "user": "test", - "password": self.user_pass, - "refresh_token": True, - } + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} login_response = self.make_request( "POST", - "/_matrix/client/r0/login", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", body, ) self.assertEqual(login_response.code, 200, login_response.result) @@ -880,7 +651,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This first refresh should work properly first_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -890,7 +661,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This one as well, since the token in the first one was never used second_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -900,7 +671,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This one should not, since the token from the first refresh is not valid anymore third_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", {"refresh_token": first_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -928,7 +699,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail fourth_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -938,7 +709,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # But refreshing from the last valid refresh token still works fifth_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", {"refresh_token": second_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 397c12c2a6..eb10d43217 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -19,7 +19,7 @@ from typing import Dict, List, Optional, Tuple from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin -from synapse.rest.client import login, register, relations, room, sync +from synapse.rest.client import login, register, relations, room from tests import unittest from tests.server import FakeChannel @@ -29,7 +29,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): servlets = [ relations.register_servlets, room.register_servlets, - sync.register_servlets, login.register_servlets, register.register_servlets, admin.register_servlets_for_client_rest_resource, @@ -455,9 +454,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(400, channel.code, channel.json_body) @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) - def test_bundled_aggregations(self): - """Test that annotations, references, and threads get correctly bundled.""" - # Setup by sending a variety of relations. + def test_aggregation_get_event(self): + """Test that annotations, references, and threads get correctly bundled when + getting the parent event. + """ + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) @@ -484,169 +485,43 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) thread_2 = channel.json_body["event_id"] - def assert_bundle(actual): - """Assert the expected values of the bundled aggregations.""" - - # Ensure the fields are as expected. - self.assertCountEqual( - actual.keys(), - ( - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - RelationTypes.THREAD, - ), - ) - - # Check the values of each field. - self.assertEquals( - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - actual[RelationTypes.ANNOTATION], - ) - - self.assertEquals( - {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, - actual[RelationTypes.REFERENCE], - ) - - self.assertEquals( - 2, - actual[RelationTypes.THREAD].get("count"), - ) - # The latest thread event has some fields that don't matter. - self.assert_dict( - { - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": RelationTypes.THREAD, - } - }, - "event_id": thread_2, - "room_id": self.room, - "sender": self.user_id, - "type": "m.room.test", - "user_id": self.user_id, - }, - actual[RelationTypes.THREAD].get("latest_event"), - ) - - def _find_and_assert_event(events): - """ - Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. - """ - for event in events: - if event["event_id"] == self.parent_id: - break - else: - raise AssertionError(f"Event {self.parent_id} not found in chunk") - assert_bundle(event["unsigned"].get("m.relations")) - - # Request the event directly. channel = self.make_request( "GET", - f"/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["unsigned"].get("m.relations")) - - # Request the room messages. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/messages?dir=b", - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - _find_and_assert_event(channel.json_body["chunk"]) - - # Request the room context. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/context/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations")) - - # Request sync. - channel = self.make_request("GET", "/sync", access_token=self.user_token) - self.assertEquals(200, channel.code, channel.json_body) - room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] - self.assertTrue(room_timeline["limited"]) - _find_and_assert_event(room_timeline["events"]) - - # Note that /relations is tested separately in test_aggregation_get_event_for_thread - # since it needs different data configured. - - def test_aggregation_get_event_for_annotation(self): - """Test that annotations do not get bundled aggregations included - when directly requested. - """ - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) - annotation_id = channel.json_body["event_id"] - - # Annotate the annotation. - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id - ) - self.assertEquals(200, channel.code, channel.json_body) - - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{annotation_id}", + "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) - def test_aggregation_get_event_for_thread(self): - """Test that threads get bundled aggregations included when directly requested.""" - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) - thread_id = channel.json_body["event_id"] - - # Annotate the annotation. - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id - ) - self.assertEquals(200, channel.code, channel.json_body) - - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{thread_id}", - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) self.assertEquals( channel.json_body["unsigned"].get("m.relations"), { RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] }, - }, - ) - - # It should also be included when the entire thread is requested. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - self.assertEqual(len(channel.json_body["chunk"]), 1) - - thread_message = channel.json_body["chunk"][0] - self.assertEquals( - thread_message["unsigned"].get("m.relations"), - { - RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] + RelationTypes.REFERENCE: { + "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] + }, + RelationTypes.THREAD: { + "count": 2, + "latest_event": { + "age": 100, + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } + }, + "event_id": thread_2, + "origin_server_ts": 1600, + "room_id": self.room, + "sender": self.user_id, + "type": "m.room.test", + "unsigned": {"age": 100}, + "user_id": self.user_id, + }, }, }, ) @@ -797,56 +672,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - def test_edit_edit(self): - """Test that an edit cannot be edited.""" - new_body = {"msgtype": "m.text", "body": "Initial edit"} - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={ - "msgtype": "m.text", - "body": "Wibble", - "m.new_content": new_body, - }, - ) - self.assertEquals(200, channel.code, channel.json_body) - edit_event_id = channel.json_body["event_id"] - - # Edit the edit event. - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={ - "msgtype": "m.text", - "body": "foo", - "m.new_content": {"msgtype": "m.text", "body": "Ignored edit"}, - }, - parent_id=edit_event_id, - ) - self.assertEquals(200, channel.code, channel.json_body) - - # Request the original event. - channel = self.make_request( - "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - # The edit to the edit should be ignored. - self.assertEquals(channel.json_body["content"], new_body) - - # The relations information should not include the edit to the edit. - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) - - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) - - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict - ) - def test_relations_redaction_redacts_edits(self): """Test that edits of an event are redacted when the original event is redacted. diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py index 913bc530aa..8fe94f7d85 100644 --- a/tests/rest/media/v1/test_filepath.py +++ b/tests/rest/media/v1/test_filepath.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -import os from typing import Iterable -from synapse.rest.media.v1.filepath import MediaFilePaths, _wrap_with_jail_check +from synapse.rest.media.v1.filepath import MediaFilePaths from tests import unittest @@ -487,109 +486,3 @@ class MediaFilePathsTestCase(unittest.TestCase): f"{value!r} unexpectedly passed validation: " f"{method} returned {path_or_list!r}" ) - - -class MediaFilePathsJailTestCase(unittest.TestCase): - def _check_relative_path(self, filepaths: MediaFilePaths, path: str) -> None: - """Passes a relative path through the jail check. - - Args: - filepaths: The `MediaFilePaths` instance. - path: A path relative to the media store directory. - - Raises: - ValueError: If the jail check fails. - """ - - @_wrap_with_jail_check(relative=True) - def _make_relative_path(self: MediaFilePaths, path: str) -> str: - return path - - _make_relative_path(filepaths, path) - - def _check_absolute_path(self, filepaths: MediaFilePaths, path: str) -> None: - """Passes an absolute path through the jail check. - - Args: - filepaths: The `MediaFilePaths` instance. - path: A path relative to the media store directory. - - Raises: - ValueError: If the jail check fails. - """ - - @_wrap_with_jail_check(relative=False) - def _make_absolute_path(self: MediaFilePaths, path: str) -> str: - return os.path.join(self.base_path, path) - - _make_absolute_path(filepaths, path) - - def test_traversal_inside(self) -> None: - """Test the jail check for paths that stay within the media directory.""" - # Despite the `../`s, these paths still lie within the media directory and it's - # expected for the jail check to allow them through. - # These paths ought to trip the other checks in place and should never be - # returned. - filepaths = MediaFilePaths("/media_store") - path = "url_cache/2020-01-02/../../GerZNDnDZVjsOtar" - self._check_relative_path(filepaths, path) - self._check_absolute_path(filepaths, path) - - def test_traversal_outside(self) -> None: - """Test that the jail check fails for paths that escape the media directory.""" - filepaths = MediaFilePaths("/media_store") - path = "url_cache/2020-01-02/../../../GerZNDnDZVjsOtar" - with self.assertRaises(ValueError): - self._check_relative_path(filepaths, path) - with self.assertRaises(ValueError): - self._check_absolute_path(filepaths, path) - - def test_traversal_reentry(self) -> None: - """Test the jail check for paths that exit and re-enter the media directory.""" - # These paths lie outside the media directory if it is a symlink, and inside - # otherwise. Ideally the check should fail, but this proves difficult. - # This test documents the behaviour for this edge case. - # These paths ought to trip the other checks in place and should never be - # returned. - filepaths = MediaFilePaths("/media_store") - path = "url_cache/2020-01-02/../../../media_store/GerZNDnDZVjsOtar" - self._check_relative_path(filepaths, path) - self._check_absolute_path(filepaths, path) - - def test_symlink(self) -> None: - """Test that a symlink does not cause the jail check to fail.""" - media_store_path = self.mktemp() - - # symlink the media store directory - os.symlink("/mnt/synapse/media_store", media_store_path) - - # Test that relative and absolute paths don't trip the check - # NB: `media_store_path` is a relative path - filepaths = MediaFilePaths(media_store_path) - self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") - self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") - - filepaths = MediaFilePaths(os.path.abspath(media_store_path)) - self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") - self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") - - def test_symlink_subdirectory(self) -> None: - """Test that a symlinked subdirectory does not cause the jail check to fail.""" - media_store_path = self.mktemp() - os.mkdir(media_store_path) - - # symlink `url_cache/` - os.symlink( - "/mnt/synapse/media_store_url_cache", - os.path.join(media_store_path, "url_cache"), - ) - - # Test that relative and absolute paths don't trip the check - # NB: `media_store_path` is a relative path - filepaths = MediaFilePaths(media_store_path) - self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") - self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") - - filepaths = MediaFilePaths(os.path.abspath(media_store_path)) - self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") - self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 5ae491ff5a..a649e8c618 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -12,24 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -from contextlib import contextmanager -from typing import Generator -from twisted.enterprise.adbapi import ConnectionPool -from twisted.internet.defer import ensureDeferred -from twisted.test.proto_helpers import MemoryReactor - -from synapse.api.room_versions import EventFormatVersions, RoomVersions from synapse.logging.context import LoggingContext from synapse.rest import admin from synapse.rest.client import login, room -from synapse.server import HomeServer -from synapse.storage.databases.main.events_worker import ( - EVENT_QUEUE_THREADS, - EventsWorkerStore, -) -from synapse.storage.types import Connection -from synapse.util import Clock +from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.util.async_helpers import yieldable_gather_results from tests import unittest @@ -157,127 +144,3 @@ class EventCacheTestCase(unittest.HomeserverTestCase): # We should have fetched the event from the DB self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) - - -class DatabaseOutageTestCase(unittest.HomeserverTestCase): - """Test event fetching during a database outage.""" - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): - self.store: EventsWorkerStore = hs.get_datastore() - - self.room_id = f"!room:{hs.hostname}" - self.event_ids = [f"event{i}" for i in range(20)] - - self._populate_events() - - def _populate_events(self) -> None: - """Ensure that there are test events in the database. - - When testing with the in-memory SQLite database, all the events are lost during - the simulated outage. - - To ensure consistency between `room_id`s and `event_id`s before and after the - outage, rows are built and inserted manually. - - Upserts are used to handle the non-SQLite case where events are not lost. - """ - self.get_success( - self.store.db_pool.simple_upsert( - "rooms", - {"room_id": self.room_id}, - {"room_version": RoomVersions.V4.identifier}, - ) - ) - - self.event_ids = [f"event{i}" for i in range(20)] - for idx, event_id in enumerate(self.event_ids): - self.get_success( - self.store.db_pool.simple_upsert( - "events", - {"event_id": event_id}, - { - "event_id": event_id, - "room_id": self.room_id, - "topological_ordering": idx, - "stream_ordering": idx, - "type": "test", - "processed": True, - "outlier": False, - }, - ) - ) - self.get_success( - self.store.db_pool.simple_upsert( - "event_json", - {"event_id": event_id}, - { - "room_id": self.room_id, - "json": json.dumps({"type": "test", "room_id": self.room_id}), - "internal_metadata": "{}", - "format_version": EventFormatVersions.V3, - }, - ) - ) - - @contextmanager - def _outage(self) -> Generator[None, None, None]: - """Simulate a database outage. - - Returns: - A context manager. While the context is active, any attempts to connect to - the database will fail. - """ - connection_pool = self.store.db_pool._db_pool - - # Close all connections and shut down the database `ThreadPool`. - connection_pool.close() - - # Restart the database `ThreadPool`. - connection_pool.start() - - original_connection_factory = connection_pool.connectionFactory - - def connection_factory(_pool: ConnectionPool) -> Connection: - raise Exception("Could not connect to the database.") - - connection_pool.connectionFactory = connection_factory # type: ignore[assignment] - try: - yield - finally: - connection_pool.connectionFactory = original_connection_factory - - # If the in-memory SQLite database is being used, all the events are gone. - # Restore the test data. - self._populate_events() - - def test_failure(self) -> None: - """Test that event fetches do not get stuck during a database outage.""" - with self._outage(): - failure = self.get_failure( - self.store.get_event(self.event_ids[0]), Exception - ) - self.assertEqual(str(failure.value), "Could not connect to the database.") - - def test_recovery(self) -> None: - """Test that event fetchers recover after a database outage.""" - with self._outage(): - # Kick off a bunch of event fetches but do not pump the reactor - event_deferreds = [] - for event_id in self.event_ids: - event_deferreds.append(ensureDeferred(self.store.get_event(event_id))) - - # We should have maxed out on event fetcher threads - self.assertEqual(self.store._event_fetch_ongoing, EVENT_QUEUE_THREADS) - - # All the event fetchers will fail - self.pump() - self.assertEqual(self.store._event_fetch_ongoing, 0) - - for event_deferred in event_deferreds: - failure = self.get_failure(event_deferred, Exception) - self.assertEqual( - str(failure.value), "Could not connect to the database." - ) - - # This next event fetch should succeed - self.get_success(self.store.get_event(self.event_ids[0])) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 329490caad..f26d5acf9c 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -14,37 +14,35 @@ import json import os import tempfile -from typing import List, Optional, cast from unittest.mock import Mock import yaml from twisted.internet import defer -from twisted.test.proto_helpers import MemoryReactor from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.config._base import ConfigError -from synapse.events import EventBase -from synapse.server import HomeServer from synapse.storage.database import DatabasePool, make_conn from synapse.storage.databases.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) -from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable +from tests.utils import setup_test_homeserver -class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): +class ApplicationServiceStoreTestCase(unittest.TestCase): + @defer.inlineCallbacks def setUp(self): - super(ApplicationServiceStoreTestCase, self).setUp() - - self.as_yaml_files: List[str] = [] + self.as_yaml_files = [] + hs = yield setup_test_homeserver( + self.addCleanup, federation_sender=Mock(), federation_client=Mock() + ) - self.hs.config.appservice.app_service_config_files = self.as_yaml_files - self.hs.config.caches.event_cache_size = 1 + hs.config.appservice.app_service_config_files = self.as_yaml_files + hs.config.caches.event_cache_size = 1 self.as_token = "token1" self.as_url = "some_url" @@ -55,14 +53,12 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - database = self.hs.get_datastores().databases[0] + database = hs.get_datastores().databases[0] self.store = ApplicationServiceStore( - database, - make_conn(database._database_config, database.engine, "test"), - self.hs, + database, make_conn(database._database_config, database.engine, "test"), hs ) - def tearDown(self) -> None: + def tearDown(self): # TODO: suboptimal that we need to create files for tests! for f in self.as_yaml_files: try: @@ -70,9 +66,7 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): except Exception: pass - super(ApplicationServiceStoreTestCase, self).tearDown() - - def _add_appservice(self, as_token, id, url, hs_token, sender) -> None: + def _add_appservice(self, as_token, id, url, hs_token, sender): as_yaml = { "url": url, "as_token": as_token, @@ -86,13 +80,12 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): outfile.write(yaml.dump(as_yaml)) self.as_yaml_files.append(as_token) - def test_retrieve_unknown_service_token(self) -> None: + def test_retrieve_unknown_service_token(self): service = self.store.get_app_service_by_token("invalid_token") self.assertEquals(service, None) - def test_retrieval_of_service(self) -> None: + def test_retrieval_of_service(self): stored_service = self.store.get_app_service_by_token(self.as_token) - assert stored_service is not None self.assertEquals(stored_service.token, self.as_token) self.assertEquals(stored_service.id, self.as_id) self.assertEquals(stored_service.url, self.as_url) @@ -100,18 +93,22 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): self.assertEquals(stored_service.namespaces[ApplicationService.NS_ROOMS], []) self.assertEquals(stored_service.namespaces[ApplicationService.NS_USERS], []) - def test_retrieval_of_all_services(self) -> None: + def test_retrieval_of_all_services(self): services = self.store.get_app_services() self.assertEquals(len(services), 3) -class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): - def setUp(self) -> None: - super(ApplicationServiceTransactionStoreTestCase, self).setUp() - self.as_yaml_files: List[str] = [] +class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.as_yaml_files = [] + + hs = yield setup_test_homeserver( + self.addCleanup, federation_sender=Mock(), federation_client=Mock() + ) - self.hs.config.appservice.app_service_config_files = self.as_yaml_files - self.hs.config.caches.event_cache_size = 1 + hs.config.appservice.app_service_config_files = self.as_yaml_files + hs.config.caches.event_cache_size = 1 self.as_list = [ {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"}, @@ -120,21 +117,21 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): {"token": "gamma_tok", "url": "https://gamma.com", "id": "id_gamma"}, ] for s in self.as_list: - self._add_service(s["url"], s["token"], s["id"]) + yield self._add_service(s["url"], s["token"], s["id"]) self.as_yaml_files = [] # We assume there is only one database in these tests - database = self.hs.get_datastores().databases[0] + database = hs.get_datastores().databases[0] self.db_pool = database._db_pool self.engine = database.engine - db_config = self.hs.config.database.get_single_database() + db_config = hs.config.database.get_single_database() self.store = TestTransactionStore( - database, make_conn(db_config, self.engine, "test"), self.hs + database, make_conn(db_config, self.engine, "test"), hs ) - def _add_service(self, url, as_token, id) -> None: + def _add_service(self, url, as_token, id): as_yaml = { "url": url, "as_token": as_token, @@ -148,15 +145,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): outfile.write(yaml.dump(as_yaml)) self.as_yaml_files.append(as_token) - def _set_state( - self, id: str, state: ApplicationServiceState, txn: Optional[int] = None - ): + def _set_state(self, id, state, txn=None): return self.db_pool.runOperation( self.engine.convert_param_style( "INSERT INTO application_services_state(as_id, state, last_txn) " "VALUES(?,?,?)" ), - (id, state.value, txn), + (id, state, txn), ) def _insert_txn(self, as_id, txn_id, events): @@ -174,277 +169,234 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): "INSERT INTO application_services_state(as_id, last_txn, state) " "VALUES(?,?,?)" ), - (as_id, txn_id, ApplicationServiceState.UP.value), + (as_id, txn_id, ApplicationServiceState.UP), ) - def test_get_appservice_state_none( - self, - ) -> None: + @defer.inlineCallbacks + def test_get_appservice_state_none(self): service = Mock(id="999") - state = self.get_success(self.store.get_appservice_state(service)) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(None, state) - def test_get_appservice_state_up( - self, - ) -> None: - self.get_success( - self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) - ) + @defer.inlineCallbacks + def test_get_appservice_state_up(self): + yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) service = Mock(id=self.as_list[0]["id"]) - state = self.get_success( - defer.ensureDeferred(self.store.get_appservice_state(service)) - ) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(ApplicationServiceState.UP, state) - def test_get_appservice_state_down( - self, - ) -> None: - self.get_success( - self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) - ) - self.get_success( - self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN) - ) - self.get_success( - self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) - ) + @defer.inlineCallbacks + def test_get_appservice_state_down(self): + yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) + yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN) + yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) service = Mock(id=self.as_list[1]["id"]) - state = self.get_success(self.store.get_appservice_state(service)) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(ApplicationServiceState.DOWN, state) - def test_get_appservices_by_state_none( - self, - ) -> None: - services = self.get_success( + @defer.inlineCallbacks + def test_get_appservices_by_state_none(self): + services = yield defer.ensureDeferred( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(0, len(services)) - def test_set_appservices_state_down( - self, - ) -> None: + @defer.inlineCallbacks + def test_set_appservices_state_down(self): service = Mock(id=self.as_list[1]["id"]) - self.get_success( + yield defer.ensureDeferred( self.store.set_appservice_state(service, ApplicationServiceState.DOWN) ) - rows = self.get_success( - self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT as_id FROM application_services_state WHERE state=?" - ), - (ApplicationServiceState.DOWN.value,), - ) + rows = yield self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT as_id FROM application_services_state WHERE state=?" + ), + (ApplicationServiceState.DOWN,), ) self.assertEquals(service.id, rows[0][0]) - def test_set_appservices_state_multiple_up( - self, - ) -> None: + @defer.inlineCallbacks + def test_set_appservices_state_multiple_up(self): service = Mock(id=self.as_list[1]["id"]) - self.get_success( + yield defer.ensureDeferred( self.store.set_appservice_state(service, ApplicationServiceState.UP) ) - self.get_success( + yield defer.ensureDeferred( self.store.set_appservice_state(service, ApplicationServiceState.DOWN) ) - self.get_success( + yield defer.ensureDeferred( self.store.set_appservice_state(service, ApplicationServiceState.UP) ) - rows = self.get_success( - self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT as_id FROM application_services_state WHERE state=?" - ), - (ApplicationServiceState.UP.value,), - ) + rows = yield self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT as_id FROM application_services_state WHERE state=?" + ), + (ApplicationServiceState.UP,), ) self.assertEquals(service.id, rows[0][0]) - def test_create_appservice_txn_first( - self, - ) -> None: + @defer.inlineCallbacks + def test_create_appservice_txn_first(self): service = Mock(id=self.as_list[0]["id"]) - events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) - txn = self.get_success( - defer.ensureDeferred(self.store.create_appservice_txn(service, events, [])) + events = [Mock(event_id="e1"), Mock(event_id="e2")] + txn = yield defer.ensureDeferred( + self.store.create_appservice_txn(service, events, []) ) self.assertEquals(txn.id, 1) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) - def test_create_appservice_txn_older_last_txn( - self, - ) -> None: + @defer.inlineCallbacks + def test_create_appservice_txn_older_last_txn(self): service = Mock(id=self.as_list[0]["id"]) - events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) - self.get_success(self._set_last_txn(service.id, 9643)) # AS is falling behind - self.get_success(self._insert_txn(service.id, 9644, events)) - self.get_success(self._insert_txn(service.id, 9645, events)) - txn = self.get_success(self.store.create_appservice_txn(service, events, [])) + events = [Mock(event_id="e1"), Mock(event_id="e2")] + yield self._set_last_txn(service.id, 9643) # AS is falling behind + yield self._insert_txn(service.id, 9644, events) + yield self._insert_txn(service.id, 9645, events) + txn = yield defer.ensureDeferred( + self.store.create_appservice_txn(service, events, []) + ) self.assertEquals(txn.id, 9646) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) - def test_create_appservice_txn_up_to_date_last_txn( - self, - ) -> None: + @defer.inlineCallbacks + def test_create_appservice_txn_up_to_date_last_txn(self): service = Mock(id=self.as_list[0]["id"]) - events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) - self.get_success(self._set_last_txn(service.id, 9643)) - txn = self.get_success(self.store.create_appservice_txn(service, events, [])) + events = [Mock(event_id="e1"), Mock(event_id="e2")] + yield self._set_last_txn(service.id, 9643) + txn = yield defer.ensureDeferred( + self.store.create_appservice_txn(service, events, []) + ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) - def test_create_appservice_txn_up_fuzzing( - self, - ) -> None: + @defer.inlineCallbacks + def test_create_appservice_txn_up_fuzzing(self): service = Mock(id=self.as_list[0]["id"]) - events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) - self.get_success(self._set_last_txn(service.id, 9643)) + events = [Mock(event_id="e1"), Mock(event_id="e2")] + yield self._set_last_txn(service.id, 9643) # dump in rows with higher IDs to make sure the queries aren't wrong. - self.get_success(self._set_last_txn(self.as_list[1]["id"], 119643)) - self.get_success(self._set_last_txn(self.as_list[2]["id"], 9)) - self.get_success(self._set_last_txn(self.as_list[3]["id"], 9643)) - self.get_success(self._insert_txn(self.as_list[1]["id"], 119644, events)) - self.get_success(self._insert_txn(self.as_list[1]["id"], 119645, events)) - self.get_success(self._insert_txn(self.as_list[1]["id"], 119646, events)) - self.get_success(self._insert_txn(self.as_list[2]["id"], 10, events)) - self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events)) - - txn = self.get_success(self.store.create_appservice_txn(service, events, [])) + yield self._set_last_txn(self.as_list[1]["id"], 119643) + yield self._set_last_txn(self.as_list[2]["id"], 9) + yield self._set_last_txn(self.as_list[3]["id"], 9643) + yield self._insert_txn(self.as_list[1]["id"], 119644, events) + yield self._insert_txn(self.as_list[1]["id"], 119645, events) + yield self._insert_txn(self.as_list[1]["id"], 119646, events) + yield self._insert_txn(self.as_list[2]["id"], 10, events) + yield self._insert_txn(self.as_list[3]["id"], 9643, events) + + txn = yield defer.ensureDeferred( + self.store.create_appservice_txn(service, events, []) + ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) - def test_complete_appservice_txn_first_txn( - self, - ) -> None: + @defer.inlineCallbacks + def test_complete_appservice_txn_first_txn(self): service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] txn_id = 1 - self.get_success(self._insert_txn(service.id, txn_id, events)) - self.get_success( + yield self._insert_txn(service.id, txn_id, events) + yield defer.ensureDeferred( self.store.complete_appservice_txn(txn_id=txn_id, service=service) ) - res = self.get_success( - self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT last_txn FROM application_services_state WHERE as_id=?" - ), - (service.id,), - ) + res = yield self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT last_txn FROM application_services_state WHERE as_id=?" + ), + (service.id,), ) self.assertEquals(1, len(res)) self.assertEquals(txn_id, res[0][0]) - res = self.get_success( - self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT * FROM application_services_txns WHERE txn_id=?" - ), - (txn_id,), - ) + res = yield self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT * FROM application_services_txns WHERE txn_id=?" + ), + (txn_id,), ) self.assertEquals(0, len(res)) - def test_complete_appservice_txn_existing_in_state_table( - self, - ) -> None: + @defer.inlineCallbacks + def test_complete_appservice_txn_existing_in_state_table(self): service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] txn_id = 5 - self.get_success(self._set_last_txn(service.id, 4)) - self.get_success(self._insert_txn(service.id, txn_id, events)) - self.get_success( + yield self._set_last_txn(service.id, 4) + yield self._insert_txn(service.id, txn_id, events) + yield defer.ensureDeferred( self.store.complete_appservice_txn(txn_id=txn_id, service=service) ) - res = self.get_success( - self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT last_txn, state FROM application_services_state WHERE as_id=?" - ), - (service.id,), - ) + res = yield self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT last_txn, state FROM application_services_state WHERE as_id=?" + ), + (service.id,), ) self.assertEquals(1, len(res)) self.assertEquals(txn_id, res[0][0]) - self.assertEquals(ApplicationServiceState.UP.value, res[0][1]) - - res = self.get_success( - self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT * FROM application_services_txns WHERE txn_id=?" - ), - (txn_id,), - ) + self.assertEquals(ApplicationServiceState.UP, res[0][1]) + + res = yield self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT * FROM application_services_txns WHERE txn_id=?" + ), + (txn_id,), ) self.assertEquals(0, len(res)) - def test_get_oldest_unsent_txn_none( - self, - ) -> None: + @defer.inlineCallbacks + def test_get_oldest_unsent_txn_none(self): service = Mock(id=self.as_list[0]["id"]) - txn = self.get_success(self.store.get_oldest_unsent_txn(service)) + txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service)) self.assertEquals(None, txn) - def test_get_oldest_unsent_txn(self) -> None: + @defer.inlineCallbacks + def test_get_oldest_unsent_txn(self): service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - # (ignore needed because Mypy won't allow us to assign to a method otherwise) - self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) # type: ignore[assignment] + self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) - self.get_success(self._insert_txn(self.as_list[1]["id"], 9, other_events)) - self.get_success(self._insert_txn(service.id, 10, events)) - self.get_success(self._insert_txn(service.id, 11, other_events)) - self.get_success(self._insert_txn(service.id, 12, other_events)) + yield self._insert_txn(self.as_list[1]["id"], 9, other_events) + yield self._insert_txn(service.id, 10, events) + yield self._insert_txn(service.id, 11, other_events) + yield self._insert_txn(service.id, 12, other_events) - txn = self.get_success(self.store.get_oldest_unsent_txn(service)) + txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service)) self.assertEquals(service, txn.service) self.assertEquals(10, txn.id) self.assertEquals(events, txn.events) - def test_get_appservices_by_state_single( - self, - ) -> None: - self.get_success( - self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) - ) - self.get_success( - self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) - ) + @defer.inlineCallbacks + def test_get_appservices_by_state_single(self): + yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) + yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) - services = self.get_success( + services = yield defer.ensureDeferred( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(1, len(services)) self.assertEquals(self.as_list[0]["id"], services[0].id) - def test_get_appservices_by_state_multiple( - self, - ) -> None: - self.get_success( - self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) - ) - self.get_success( - self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) - ) - self.get_success( - self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) - ) - self.get_success( - self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP) - ) + @defer.inlineCallbacks + def test_get_appservices_by_state_multiple(self): + yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) + yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) + yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) + yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP) - services = self.get_success( + services = yield defer.ensureDeferred( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(2, len(services)) @@ -455,16 +407,16 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): - def prepare( - self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer - ) -> None: + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + return hs + + def prepare(self, hs, reactor, clock): self.service = Mock(id="foo") self.store = self.hs.get_datastore() - self.get_success( - self.store.set_appservice_state(self.service, ApplicationServiceState.UP) - ) + self.get_success(self.store.set_appservice_state(self.service, "up")) - def test_get_type_stream_id_for_appservice_no_value(self) -> None: + def test_get_type_stream_id_for_appservice_no_value(self): value = self.get_success( self.store.get_type_stream_id_for_appservice(self.service, "read_receipt") ) @@ -475,13 +427,13 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): ) self.assertEquals(value, 0) - def test_get_type_stream_id_for_appservice_invalid_type(self) -> None: + def test_get_type_stream_id_for_appservice_invalid_type(self): self.get_failure( self.store.get_type_stream_id_for_appservice(self.service, "foobar"), ValueError, ) - def test_set_type_stream_id_for_appservice(self) -> None: + def test_set_type_stream_id_for_appservice(self): read_receipt_value = 1024 self.get_success( self.store.set_type_stream_id_for_appservice( @@ -503,7 +455,7 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): ) self.assertEqual(result, read_receipt_value) - def test_set_type_stream_id_for_appservice_invalid_type(self) -> None: + def test_set_type_stream_id_for_appservice_invalid_type(self): self.get_failure( self.store.set_type_stream_id_for_appservice(self.service, "foobar", 1024), ValueError, @@ -512,12 +464,12 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): - def __init__(self, database: DatabasePool, db_conn, hs) -> None: + def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) -class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase): - def _write_config(self, suffix, **kwargs) -> str: +class ApplicationServiceStoreConfigTestCase(unittest.TestCase): + def _write_config(self, suffix, **kwargs): vals = { "id": "id" + suffix, "url": "url" + suffix, @@ -533,33 +485,41 @@ class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase): f.write(yaml.dump(vals)) return path - def test_unique_works(self) -> None: + @defer.inlineCallbacks + def test_unique_works(self): f1 = self._write_config(suffix="1") f2 = self._write_config(suffix="2") - self.hs.config.appservice.app_service_config_files = [f1, f2] - self.hs.config.caches.event_cache_size = 1 + hs = yield setup_test_homeserver( + self.addCleanup, federation_sender=Mock(), federation_client=Mock() + ) + + hs.config.appservice.app_service_config_files = [f1, f2] + hs.config.caches.event_cache_size = 1 - database = self.hs.get_datastores().databases[0] + database = hs.get_datastores().databases[0] ApplicationServiceStore( - database, - make_conn(database._database_config, database.engine, "test"), - self.hs, + database, make_conn(database._database_config, database.engine, "test"), hs ) - def test_duplicate_ids(self) -> None: + @defer.inlineCallbacks + def test_duplicate_ids(self): f1 = self._write_config(id="id", suffix="1") f2 = self._write_config(id="id", suffix="2") - self.hs.config.appservice.app_service_config_files = [f1, f2] - self.hs.config.caches.event_cache_size = 1 + hs = yield setup_test_homeserver( + self.addCleanup, federation_sender=Mock(), federation_client=Mock() + ) + + hs.config.appservice.app_service_config_files = [f1, f2] + hs.config.caches.event_cache_size = 1 with self.assertRaises(ConfigError) as cm: - database = self.hs.get_datastores().databases[0] + database = hs.get_datastores().databases[0] ApplicationServiceStore( database, make_conn(database._database_config, database.engine, "test"), - self.hs, + hs, ) e = cm.exception @@ -567,19 +527,24 @@ class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase): self.assertIn(f2, str(e)) self.assertIn("id", str(e)) - def test_duplicate_as_tokens(self) -> None: + @defer.inlineCallbacks + def test_duplicate_as_tokens(self): f1 = self._write_config(as_token="as_token", suffix="1") f2 = self._write_config(as_token="as_token", suffix="2") - self.hs.config.appservice.app_service_config_files = [f1, f2] - self.hs.config.caches.event_cache_size = 1 + hs = yield setup_test_homeserver( + self.addCleanup, federation_sender=Mock(), federation_client=Mock() + ) + + hs.config.appservice.app_service_config_files = [f1, f2] + hs.config.caches.event_cache_size = 1 with self.assertRaises(ConfigError) as cm: - database = self.hs.get_datastores().databases[0] + database = hs.get_datastores().databases[0] ApplicationServiceStore( database, make_conn(database._database_config, database.engine, "test"), - self.hs, + hs, ) e = cm.exception diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index d77c001506..a5f5ebad41 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -1,26 +1,8 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Use backported mock for AsyncMock support on Python 3.6. -from mock import Mock - -from twisted.internet.defer import Deferred, ensureDeferred +from unittest.mock import Mock from synapse.storage.background_updates import BackgroundUpdater from tests import unittest -from tests.test_utils import make_awaitable class BackgroundUpdateTestCase(unittest.HomeserverTestCase): @@ -38,10 +20,10 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): def test_do_background_update(self): # the time we claim it takes to update one item when running the update - duration_ms = 10 + duration_ms = 4200 # the target runtime for each bg update - target_background_update_duration_ms = 100 + target_background_update_duration_ms = 5000000 store = self.hs.get_datastore() self.get_success( @@ -66,8 +48,10 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): self.update_handler.side_effect = update self.update_handler.reset_mock() res = self.get_success( - self.updates.do_next_background_update(False), - by=0.01, + self.updates.do_next_background_update( + target_background_update_duration_ms + ), + by=0.1, ) self.assertFalse(res) @@ -90,93 +74,16 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): self.update_handler.side_effect = update self.update_handler.reset_mock() - result = self.get_success(self.updates.do_next_background_update(False)) + result = self.get_success( + self.updates.do_next_background_update(target_background_update_duration_ms) + ) self.assertFalse(result) self.update_handler.assert_called_once() # third step: we don't expect to be called any more self.update_handler.reset_mock() - result = self.get_success(self.updates.do_next_background_update(False)) + result = self.get_success( + self.updates.do_next_background_update(target_background_update_duration_ms) + ) self.assertTrue(result) self.assertFalse(self.update_handler.called) - - -class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): - self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates - # the base test class should have run the real bg updates for us - self.assertTrue( - self.get_success(self.updates.has_completed_background_updates()) - ) - - self.update_deferred = Deferred() - self.update_handler = Mock(return_value=self.update_deferred) - self.updates.register_background_update_handler( - "test_update", self.update_handler - ) - - # Mock out the AsyncContextManager - self._update_ctx_manager = Mock(spec=["__aenter__", "__aexit__"]) - self._update_ctx_manager.__aenter__ = Mock( - return_value=make_awaitable(None), - ) - self._update_ctx_manager.__aexit__ = Mock(return_value=make_awaitable(None)) - - # Mock out the `update_handler` callback - self._on_update = Mock(return_value=self._update_ctx_manager) - - # Define a default batch size value that's not the same as the internal default - # value (100). - self._default_batch_size = 500 - - # Register the callbacks with more mocks - self.hs.get_module_api().register_background_update_controller_callbacks( - on_update=self._on_update, - min_batch_size=Mock(return_value=make_awaitable(self._default_batch_size)), - default_batch_size=Mock( - return_value=make_awaitable(self._default_batch_size), - ), - ) - - def test_controller(self): - store = self.hs.get_datastore() - self.get_success( - store.db_pool.simple_insert( - "background_updates", - values={"update_name": "test_update", "progress_json": "{}"}, - ) - ) - - # Set the return value for the context manager. - enter_defer = Deferred() - self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer) - - # Start the background update. - do_update_d = ensureDeferred(self.updates.do_next_background_update(True)) - - self.pump() - - # `run_update` should have been called, but the update handler won't be - # called until the `enter_defer` (returned by `__aenter__`) is resolved. - self._on_update.assert_called_once_with( - "test_update", - "master", - False, - ) - self.assertFalse(do_update_d.called) - self.assertFalse(self.update_deferred.called) - - # Resolving the `enter_defer` should call the update handler, which then - # blocks. - enter_defer.callback(100) - self.pump() - self.update_handler.assert_called_once_with({}, self._default_batch_size) - self.assertFalse(self.update_deferred.called) - self._update_ctx_manager.__aexit__.assert_not_called() - - # Resolving the update handler deferred should cause the - # `do_next_background_update` to finish and return - self.update_deferred.callback(100) - self.pump() - self._update_ctx_manager.__aexit__.assert_called() - self.get_success(do_update_d) diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index 7b7f6c349e..b31c5eb5ec 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -664,7 +664,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): ): iterations += 1 self.get_success( - self.store.db_pool.updates.do_next_background_update(False), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) # Ensure that we did actually take multiple iterations to process the @@ -723,7 +723,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): ): iterations += 1 self.get_success( - self.store.db_pool.updates.do_next_background_update(False), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) # Ensure that we did actually take multiple iterations to process the diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index f8d11bac4e..d2b7b89952 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -13,35 +13,42 @@ # limitations under the License. +from twisted.internet import defer + from synapse.types import UserID from tests import unittest +from tests.utils import setup_test_homeserver -class DataStoreTestCase(unittest.HomeserverTestCase): - def setUp(self) -> None: - super(DataStoreTestCase, self).setUp() +class DataStoreTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + hs = yield setup_test_homeserver(self.addCleanup) - self.store = self.hs.get_datastore() + self.store = hs.get_datastore() self.user = UserID.from_string("@abcde:test") self.displayname = "Frank" - def test_get_users_paginate(self) -> None: - self.get_success(self.store.register_user(self.user.to_string(), "pass")) - self.get_success(self.store.create_profile(self.user.localpart)) - self.get_success( + @defer.inlineCallbacks + def test_get_users_paginate(self): + yield defer.ensureDeferred( + self.store.register_user(self.user.to_string(), "pass") + ) + yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) + yield defer.ensureDeferred( self.store.set_profile_displayname(self.user.localpart, self.displayname) ) - users, total = self.get_success( + users, total = yield defer.ensureDeferred( self.store.get_users_paginate(0, 10, name="bc", guests=False) ) self.assertEquals(1, total) self.assertEquals(self.displayname, users.pop()["displayname"]) - users, total = self.get_success( + users, total = yield defer.ensureDeferred( self.store.get_users_paginate(0, 10, name="BC", guests=False) ) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 7f5b28aed8..37cf7bb232 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -23,7 +23,6 @@ from synapse.rest import admin from synapse.rest.client import login, register, room from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.storage.background_updates import _BackgroundUpdateHandler from synapse.storage.roommember import ProfileInfo from synapse.util import Clock @@ -392,9 +391,7 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): with mock.patch.dict( self.store.db_pool.updates._background_update_handlers, - populate_user_directory_process_users=_BackgroundUpdateHandler( - mocked_process_users, - ), + populate_user_directory_process_users=mocked_process_users, ): self._purge_and_rebuild_user_dir() diff --git a/tests/test_visibility.py b/tests/test_visibility.py index e0b08d67d4..94b19788d7 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -13,30 +13,35 @@ # limitations under the License. import logging from typing import Optional +from unittest.mock import Mock + +from twisted.internet import defer +from twisted.internet.defer import succeed from synapse.api.room_versions import RoomVersions -from synapse.events import EventBase -from synapse.types import JsonDict +from synapse.events import FrozenEvent from synapse.visibility import filter_events_for_server -from tests import unittest -from tests.utils import create_room +import tests.unittest +from tests.utils import create_room, setup_test_homeserver logger = logging.getLogger(__name__) TEST_ROOM_ID = "!TEST:ROOM" -class FilterEventsForServerTestCase(unittest.HomeserverTestCase): - def setUp(self) -> None: - super(FilterEventsForServerTestCase, self).setUp() +class FilterEventsForServerTestCase(tests.unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver(self.addCleanup) self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self.storage = self.hs.get_storage() - self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) + yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) - def test_filtering(self) -> None: + @defer.inlineCallbacks + def test_filtering(self): # # The events to be filtered consist of 10 membership events (it doesn't # really matter if they are joins or leaves, so let's make them joins). @@ -46,20 +51,18 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # # before we do that, we persist some other events to act as state. - self.get_success(self._inject_visibility("@admin:hs", "joined")) + yield self.inject_visibility("@admin:hs", "joined") for i in range(0, 10): - self.get_success(self._inject_room_member("@resident%i:hs" % i)) + yield self.inject_room_member("@resident%i:hs" % i) events_to_filter = [] for i in range(0, 10): user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server") - evt = self.get_success( - self._inject_room_member(user, extra_content={"a": "b"}) - ) + evt = yield self.inject_room_member(user, extra_content={"a": "b"}) events_to_filter.append(evt) - filtered = self.get_success( + filtered = yield defer.ensureDeferred( filter_events_for_server(self.storage, "test_server", events_to_filter) ) @@ -72,31 +75,34 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) self.assertEqual(filtered[i].content["a"], "b") - def test_erased_user(self) -> None: + @defer.inlineCallbacks + def test_erased_user(self): # 4 message events, from erased and unerased users, with a membership # change in the middle of them. events_to_filter = [] - evt = self.get_success(self._inject_message("@unerased:local_hs")) + evt = yield self.inject_message("@unerased:local_hs") events_to_filter.append(evt) - evt = self.get_success(self._inject_message("@erased:local_hs")) + evt = yield self.inject_message("@erased:local_hs") events_to_filter.append(evt) - evt = self.get_success(self._inject_room_member("@joiner:remote_hs")) + evt = yield self.inject_room_member("@joiner:remote_hs") events_to_filter.append(evt) - evt = self.get_success(self._inject_message("@unerased:local_hs")) + evt = yield self.inject_message("@unerased:local_hs") events_to_filter.append(evt) - evt = self.get_success(self._inject_message("@erased:local_hs")) + evt = yield self.inject_message("@erased:local_hs") events_to_filter.append(evt) # the erasey user gets erased - self.get_success(self.hs.get_datastore().mark_user_erased("@erased:local_hs")) + yield defer.ensureDeferred( + self.hs.get_datastore().mark_user_erased("@erased:local_hs") + ) # ... and the filtering happens. - filtered = self.get_success( + filtered = yield defer.ensureDeferred( filter_events_for_server(self.storage, "test_server", events_to_filter) ) @@ -117,7 +123,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): for i in (1, 4): self.assertNotIn("body", filtered[i].content) - def _inject_visibility(self, user_id: str, visibility: str) -> EventBase: + @defer.inlineCallbacks + def inject_visibility(self, user_id, visibility): content = {"history_visibility": visibility} builder = self.event_builder_factory.for_room_version( RoomVersions.V1, @@ -130,18 +137,18 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, context = yield defer.ensureDeferred( self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event - def _inject_room_member( - self, - user_id: str, - membership: str = "join", - extra_content: Optional[JsonDict] = None, - ) -> EventBase: + @defer.inlineCallbacks + def inject_room_member( + self, user_id, membership="join", extra_content: Optional[dict] = None + ): content = {"membership": membership} content.update(extra_content or {}) builder = self.event_builder_factory.for_room_version( @@ -155,16 +162,17 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, context = yield defer.ensureDeferred( self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event - def _inject_message( - self, user_id: str, content: Optional[JsonDict] = None - ) -> EventBase: + @defer.inlineCallbacks + def inject_message(self, user_id, content=None): if content is None: content = {"body": "testytest", "msgtype": "m.text"} builder = self.event_builder_factory.for_room_version( @@ -177,9 +185,164 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, context = yield defer.ensureDeferred( self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event + + @defer.inlineCallbacks + def test_large_room(self): + # see what happens when we have a large room with hundreds of thousands + # of membership events + + # As above, the events to be filtered consist of 10 membership events, + # where one of them is for a user on the server we are filtering for. + + import cProfile + import pstats + import time + + # we stub out the store, because building up all that state the normal + # way is very slow. + test_store = _TestStore() + + # our initial state is 100000 membership events and one + # history_visibility event. + room_state = [] + + history_visibility_evt = FrozenEvent( + { + "event_id": "$history_vis", + "type": "m.room.history_visibility", + "sender": "@resident_user_0:test.com", + "state_key": "", + "room_id": TEST_ROOM_ID, + "content": {"history_visibility": "joined"}, + } + ) + room_state.append(history_visibility_evt) + test_store.add_event(history_visibility_evt) + + for i in range(0, 100000): + user = "@resident_user_%i:test.com" % (i,) + evt = FrozenEvent( + { + "event_id": "$res_event_%i" % (i,), + "type": "m.room.member", + "state_key": user, + "sender": user, + "room_id": TEST_ROOM_ID, + "content": {"membership": "join", "extra": "zzz,"}, + } + ) + room_state.append(evt) + test_store.add_event(evt) + + events_to_filter = [] + for i in range(0, 10): + user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server") + evt = FrozenEvent( + { + "event_id": "$evt%i" % (i,), + "type": "m.room.member", + "state_key": user, + "sender": user, + "room_id": TEST_ROOM_ID, + "content": {"membership": "join", "extra": "zzz"}, + } + ) + events_to_filter.append(evt) + room_state.append(evt) + + test_store.add_event(evt) + test_store.set_state_ids_for_event( + evt, {(e.type, e.state_key): e.event_id for e in room_state} + ) + + pr = cProfile.Profile() + pr.enable() + + logger.info("Starting filtering") + start = time.time() + + storage = Mock() + storage.main = test_store + storage.state = test_store + + filtered = yield defer.ensureDeferred( + filter_events_for_server(test_store, "test_server", events_to_filter) + ) + logger.info("Filtering took %f seconds", time.time() - start) + + pr.disable() + with open("filter_events_for_server.profile", "w+") as f: + ps = pstats.Stats(pr, stream=f).sort_stats("cumulative") + ps.print_stats() + + # the result should be 5 redacted events, and 5 unredacted events. + for i in range(0, 5): + self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) + self.assertNotIn("extra", filtered[i].content) + + for i in range(5, 10): + self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) + self.assertEqual(filtered[i].content["extra"], "zzz") + + test_large_room.skip = "Disabled by default because it's slow" + + +class _TestStore: + """Implements a few methods of the DataStore, so that we can test + filter_events_for_server + + """ + + def __init__(self): + # data for get_events: a map from event_id to event + self.events = {} + + # data for get_state_ids_for_events mock: a map from event_id to + # a map from (type_state_key) -> event_id for the state at that + # event + self.state_ids_for_events = {} + + def add_event(self, event): + self.events[event.event_id] = event + + def set_state_ids_for_event(self, event, state): + self.state_ids_for_events[event.event_id] = state + + def get_state_ids_for_events(self, events, types): + res = {} + include_memberships = False + for (type, state_key) in types: + if type == "m.room.history_visibility": + continue + if type != "m.room.member" or state_key is not None: + raise RuntimeError( + "Unimplemented: get_state_ids with type (%s, %s)" + % (type, state_key) + ) + include_memberships = True + + if include_memberships: + for event_id in events: + res[event_id] = self.state_ids_for_events[event_id] + + else: + k = ("m.room.history_visibility", "") + for event_id in events: + hve = self.state_ids_for_events[event_id][k] + res[event_id] = {k: hve} + + return succeed(res) + + def get_events(self, events): + return succeed({event_id: self.events[event_id] for event_id in events}) + + def are_users_erased(self, users): + return succeed({u: False for u in users}) diff --git a/tests/unittest.py b/tests/unittest.py index eea0903f05..165aafc574 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -331,16 +331,17 @@ class HomeserverTestCase(TestCase): time.sleep(0.01) def wait_for_background_updates(self) -> None: - """Block until all background database updates have completed. + """ + Block until all background database updates have completed. - Note that callers must ensure there's a store property created on the + Note that callers must ensure that's a store property created on the testcase. """ while not self.get_success( self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db_pool.updates.do_next_background_update(False), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) def make_homeserver(self, reactor, clock): @@ -499,7 +500,8 @@ class HomeserverTestCase(TestCase): async def run_bg_updates(): with LoggingContext("run_bg_updates"): - self.get_success(stor.db_pool.updates.run_background_updates(False)) + while not await stor.db_pool.updates.has_completed_background_updates(): + await stor.db_pool.updates.do_next_background_update(1) hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 291644eb7d..6578f3411e 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -13,7 +13,6 @@ # limitations under the License. -from typing import List from unittest.mock import Mock from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries @@ -262,17 +261,6 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase): self.assertEquals(cache["key4"], [4]) self.assertEquals(cache["key5"], [5, 6]) - def test_zero_size_drop_from_cache(self) -> None: - """Test that `drop_from_cache` works correctly with 0-sized entries.""" - cache: LruCache[str, List[int]] = LruCache(5, size_callback=lambda x: 0) - cache["key1"] = [] - - self.assertEqual(len(cache), 0) - cache.cache["key1"].drop_from_cache() - self.assertIsNone( - cache.pop("key1"), "Cache entry should have been evicted but wasn't" - ) - class TimeEvictionTestCase(unittest.HomeserverTestCase): """Test that time based eviction works correctly.""" |