diff options
Diffstat (limited to 'tests/handlers/test_password_providers.py')
-rw-r--r-- | tests/handlers/test_password_providers.py | 144 |
1 files changed, 74 insertions, 70 deletions
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 75934b1707..0916de64f5 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -15,12 +15,13 @@ """Tests for the password_auth_provider interface""" from http import HTTPStatus -from typing import Any, Type, Union +from typing import Any, Dict, List, Optional, Type, Union from unittest.mock import Mock import synapse from synapse.api.constants import LoginType from synapse.api.errors import Codes +from synapse.handlers.account import AccountHandler from synapse.module_api import ModuleApi from synapse.rest.client import account, devices, login, logout, register from synapse.types import JsonDict, UserID @@ -44,13 +45,13 @@ class LegacyPasswordOnlyAuthProvider: """A legacy password_provider which only implements `check_password`.""" @staticmethod - def parse_config(self): + def parse_config(config: JsonDict) -> None: pass - def __init__(self, config, account_handler): + def __init__(self, config: None, account_handler: AccountHandler): pass - def check_password(self, *args): + def check_password(self, *args: str) -> Mock: return mock_password_provider.check_password(*args) @@ -58,16 +59,16 @@ class LegacyCustomAuthProvider: """A legacy password_provider which implements a custom login type.""" @staticmethod - def parse_config(self): + def parse_config(config: JsonDict) -> None: pass - def __init__(self, config, account_handler): + def __init__(self, config: None, account_handler: AccountHandler): pass - def get_supported_login_types(self): + def get_supported_login_types(self) -> Dict[str, List[str]]: return {"test.login_type": ["test_field"]} - def check_auth(self, *args): + def check_auth(self, *args: str) -> Mock: return mock_password_provider.check_auth(*args) @@ -75,15 +76,15 @@ class CustomAuthProvider: """A module which registers password_auth_provider callbacks for a custom login type.""" @staticmethod - def parse_config(self): + def parse_config(config: JsonDict) -> None: pass - def __init__(self, config, api: ModuleApi): + def __init__(self, config: None, api: ModuleApi): api.register_password_auth_provider_callbacks( auth_checkers={("test.login_type", ("test_field",)): self.check_auth} ) - def check_auth(self, *args): + def check_auth(self, *args: Any) -> Mock: return mock_password_provider.check_auth(*args) @@ -92,16 +93,16 @@ class LegacyPasswordCustomAuthProvider: as a custom type.""" @staticmethod - def parse_config(self): + def parse_config(config: JsonDict) -> None: pass - def __init__(self, config, account_handler): + def __init__(self, config: None, account_handler: AccountHandler): pass - def get_supported_login_types(self): + def get_supported_login_types(self) -> Dict[str, List[str]]: return {"m.login.password": ["password"], "test.login_type": ["test_field"]} - def check_auth(self, *args): + def check_auth(self, *args: str) -> Mock: return mock_password_provider.check_auth(*args) @@ -110,10 +111,10 @@ class PasswordCustomAuthProvider: as well as a password login""" @staticmethod - def parse_config(self): + def parse_config(config: JsonDict) -> None: pass - def __init__(self, config, api: ModuleApi): + def __init__(self, config: None, api: ModuleApi): api.register_password_auth_provider_callbacks( auth_checkers={ ("test.login_type", ("test_field",)): self.check_auth, @@ -121,10 +122,10 @@ class PasswordCustomAuthProvider: } ) - def check_auth(self, *args): + def check_auth(self, *args: Any) -> Mock: return mock_password_provider.check_auth(*args) - def check_pass(self, *args): + def check_pass(self, *args: str) -> Mock: return mock_password_provider.check_password(*args) @@ -161,16 +162,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): CALLBACK_USERNAME = "get_username_for_registration" CALLBACK_DISPLAYNAME = "get_displayname_for_registration" - def setUp(self): + def setUp(self) -> None: # we use a global mock device, so make sure we are starting with a clean slate mock_password_provider.reset_mock() super().setUp() @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) - def test_password_only_auth_progiver_login_legacy(self): + def test_password_only_auth_progiver_login_legacy(self) -> None: self.password_only_auth_provider_login_test_body() - def password_only_auth_provider_login_test_body(self): + def password_only_auth_provider_login_test_body(self) -> None: # login flows should only have m.login.password flows = self._get_login_flows() self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) @@ -201,10 +202,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): ) @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) - def test_password_only_auth_provider_ui_auth_legacy(self): + def test_password_only_auth_provider_ui_auth_legacy(self) -> None: self.password_only_auth_provider_ui_auth_test_body() - def password_only_auth_provider_ui_auth_test_body(self): + def password_only_auth_provider_ui_auth_test_body(self) -> None: """UI Auth should delegate correctly to the password provider""" # create the user, otherwise access doesn't work @@ -238,10 +239,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.check_password.assert_called_once_with("@u:test", "p") @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) - def test_local_user_fallback_login_legacy(self): + def test_local_user_fallback_login_legacy(self) -> None: self.local_user_fallback_login_test_body() - def local_user_fallback_login_test_body(self): + def local_user_fallback_login_test_body(self) -> None: """rejected login should fall back to local db""" self.register_user("localuser", "localpass") @@ -255,10 +256,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual("@localuser:test", channel.json_body["user_id"]) @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) - def test_local_user_fallback_ui_auth_legacy(self): + def test_local_user_fallback_ui_auth_legacy(self) -> None: self.local_user_fallback_ui_auth_test_body() - def local_user_fallback_ui_auth_test_body(self): + def local_user_fallback_ui_auth_test_body(self) -> None: """rejected login should fall back to local db""" self.register_user("localuser", "localpass") @@ -298,10 +299,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"localdb_enabled": False}, } ) - def test_no_local_user_fallback_login_legacy(self): + def test_no_local_user_fallback_login_legacy(self) -> None: self.no_local_user_fallback_login_test_body() - def no_local_user_fallback_login_test_body(self): + def no_local_user_fallback_login_test_body(self) -> None: """localdb_enabled can block login with the local password""" self.register_user("localuser", "localpass") @@ -320,10 +321,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"localdb_enabled": False}, } ) - def test_no_local_user_fallback_ui_auth_legacy(self): + def test_no_local_user_fallback_ui_auth_legacy(self) -> None: self.no_local_user_fallback_ui_auth_test_body() - def no_local_user_fallback_ui_auth_test_body(self): + def no_local_user_fallback_ui_auth_test_body(self) -> None: """localdb_enabled can block ui auth with the local password""" self.register_user("localuser", "localpass") @@ -361,10 +362,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False}, } ) - def test_password_auth_disabled_legacy(self): + def test_password_auth_disabled_legacy(self) -> None: self.password_auth_disabled_test_body() - def password_auth_disabled_test_body(self): + def password_auth_disabled_test_body(self) -> None: """password auth doesn't work if it's disabled across the board""" # login flows should be empty flows = self._get_login_flows() @@ -376,14 +377,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.check_password.assert_not_called() @override_config(legacy_providers_config(LegacyCustomAuthProvider)) - def test_custom_auth_provider_login_legacy(self): + def test_custom_auth_provider_login_legacy(self) -> None: self.custom_auth_provider_login_test_body() @override_config(providers_config(CustomAuthProvider)) - def test_custom_auth_provider_login(self): + def test_custom_auth_provider_login(self) -> None: self.custom_auth_provider_login_test_body() - def custom_auth_provider_login_test_body(self): + def custom_auth_provider_login_test_body(self) -> None: # login flows should have the custom flow and m.login.password, since we # haven't disabled local password lookup. # (password must come first, because reasons) @@ -424,14 +425,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): ) @override_config(legacy_providers_config(LegacyCustomAuthProvider)) - def test_custom_auth_provider_ui_auth_legacy(self): + def test_custom_auth_provider_ui_auth_legacy(self) -> None: self.custom_auth_provider_ui_auth_test_body() @override_config(providers_config(CustomAuthProvider)) - def test_custom_auth_provider_ui_auth(self): + def test_custom_auth_provider_ui_auth(self) -> None: self.custom_auth_provider_ui_auth_test_body() - def custom_auth_provider_ui_auth_test_body(self): + def custom_auth_provider_ui_auth_test_body(self) -> None: # register the user and log in twice, to get two devices self.register_user("localuser", "localpass") tok1 = self.login("localuser", "localpass") @@ -486,14 +487,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): ) @override_config(legacy_providers_config(LegacyCustomAuthProvider)) - def test_custom_auth_provider_callback_legacy(self): + def test_custom_auth_provider_callback_legacy(self) -> None: self.custom_auth_provider_callback_test_body() @override_config(providers_config(CustomAuthProvider)) - def test_custom_auth_provider_callback(self): + def test_custom_auth_provider_callback(self) -> None: self.custom_auth_provider_callback_test_body() - def custom_auth_provider_callback_test_body(self): + def custom_auth_provider_callback_test_body(self) -> None: callback = Mock(return_value=make_awaitable(None)) mock_password_provider.check_auth.return_value = make_awaitable( @@ -521,16 +522,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False}, } ) - def test_custom_auth_password_disabled_legacy(self): + def test_custom_auth_password_disabled_legacy(self) -> None: self.custom_auth_password_disabled_test_body() @override_config( {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}} ) - def test_custom_auth_password_disabled(self): + def test_custom_auth_password_disabled(self) -> None: self.custom_auth_password_disabled_test_body() - def custom_auth_password_disabled_test_body(self): + def custom_auth_password_disabled_test_body(self) -> None: """Test login with a custom auth provider where password login is disabled""" self.register_user("localuser", "localpass") @@ -548,7 +549,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False, "localdb_enabled": False}, } ) - def test_custom_auth_password_disabled_localdb_enabled_legacy(self): + def test_custom_auth_password_disabled_localdb_enabled_legacy(self) -> None: self.custom_auth_password_disabled_localdb_enabled_test_body() @override_config( @@ -557,10 +558,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False, "localdb_enabled": False}, } ) - def test_custom_auth_password_disabled_localdb_enabled(self): + def test_custom_auth_password_disabled_localdb_enabled(self) -> None: self.custom_auth_password_disabled_localdb_enabled_test_body() - def custom_auth_password_disabled_localdb_enabled_test_body(self): + def custom_auth_password_disabled_localdb_enabled_test_body(self) -> None: """Check the localdb_enabled == enabled == False Regression test for https://github.com/matrix-org/synapse/issues/8914: check @@ -583,7 +584,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False}, } ) - def test_password_custom_auth_password_disabled_login_legacy(self): + def test_password_custom_auth_password_disabled_login_legacy(self) -> None: self.password_custom_auth_password_disabled_login_test_body() @override_config( @@ -592,10 +593,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False}, } ) - def test_password_custom_auth_password_disabled_login(self): + def test_password_custom_auth_password_disabled_login(self) -> None: self.password_custom_auth_password_disabled_login_test_body() - def password_custom_auth_password_disabled_login_test_body(self): + def password_custom_auth_password_disabled_login_test_body(self) -> None: """log in with a custom auth provider which implements password, but password login is disabled""" self.register_user("localuser", "localpass") @@ -615,7 +616,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False}, } ) - def test_password_custom_auth_password_disabled_ui_auth_legacy(self): + def test_password_custom_auth_password_disabled_ui_auth_legacy(self) -> None: self.password_custom_auth_password_disabled_ui_auth_test_body() @override_config( @@ -624,10 +625,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False}, } ) - def test_password_custom_auth_password_disabled_ui_auth(self): + def test_password_custom_auth_password_disabled_ui_auth(self) -> None: self.password_custom_auth_password_disabled_ui_auth_test_body() - def password_custom_auth_password_disabled_ui_auth_test_body(self): + def password_custom_auth_password_disabled_ui_auth_test_body(self) -> None: """UI Auth with a custom auth provider which implements password, but password login is disabled""" # register the user and log in twice via the test login type to get two devices, @@ -689,7 +690,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"localdb_enabled": False}, } ) - def test_custom_auth_no_local_user_fallback_legacy(self): + def test_custom_auth_no_local_user_fallback_legacy(self) -> None: self.custom_auth_no_local_user_fallback_test_body() @override_config( @@ -698,10 +699,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"localdb_enabled": False}, } ) - def test_custom_auth_no_local_user_fallback(self): + def test_custom_auth_no_local_user_fallback(self) -> None: self.custom_auth_no_local_user_fallback_test_body() - def custom_auth_no_local_user_fallback_test_body(self): + def custom_auth_no_local_user_fallback_test_body(self) -> None: """Test login with a custom auth provider where the local db is disabled""" self.register_user("localuser", "localpass") @@ -713,14 +714,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): channel = self._send_password_login("localuser", "localpass") self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) - def test_on_logged_out(self): + def test_on_logged_out(self) -> None: """Tests that the on_logged_out callback is called when the user logs out.""" self.register_user("rin", "password") tok = self.login("rin", "password") self.called = False - async def on_logged_out(user_id, device_id, access_token): + async def on_logged_out( + user_id: str, device_id: Optional[str], access_token: str + ) -> None: self.called = True on_logged_out = Mock(side_effect=on_logged_out) @@ -738,7 +741,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): on_logged_out.assert_called_once() self.assertTrue(self.called) - def test_username(self): + def test_username(self) -> None: """Tests that the get_username_for_registration callback can define the username of a user when registering. """ @@ -763,7 +766,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mxid = channel.json_body["user_id"] self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo") - def test_username_uia(self): + def test_username_uia(self) -> None: """Tests that the get_username_for_registration callback is only called at the end of the UIA flow. """ @@ -782,7 +785,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # Set some email configuration so the test doesn't fail because of its absence. @override_config({"email": {"notif_from": "noreply@test"}}) - def test_3pid_allowed(self): + def test_3pid_allowed(self) -> None: """Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind the 3PID. Also checks that the module is passed a boolean indicating whether the @@ -791,7 +794,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self._test_3pid_allowed("rin", False) self._test_3pid_allowed("kitay", True) - def test_displayname(self): + def test_displayname(self) -> None: """Tests that the get_displayname_for_registration callback can define the display name of a user when registering. """ @@ -820,7 +823,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(display_name, username + "-foo") - def test_displayname_uia(self): + def test_displayname_uia(self) -> None: """Tests that the get_displayname_for_registration callback is only called at the end of the UIA flow. """ @@ -841,7 +844,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # Check that the callback has been called. m.assert_called_once() - def _test_3pid_allowed(self, username: str, registration: bool): + def _test_3pid_allowed(self, username: str, registration: bool) -> None: """Tests that the "is_3pid_allowed" module callback is called correctly, using either /register or /account URLs depending on the arguments. @@ -907,7 +910,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): client is trying to register. """ - async def callback(uia_results, params): + async def callback(uia_results: JsonDict, params: JsonDict) -> str: self.assertIn(LoginType.DUMMY, uia_results) username = params["username"] return username + "-foo" @@ -950,12 +953,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): def _send_password_login(self, user: str, password: str) -> FakeChannel: return self._send_login(type="m.login.password", user=user, password=password) - def _send_login(self, type, user, **params) -> FakeChannel: - params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type}) + def _send_login(self, type: str, user: str, **extra_params: str) -> FakeChannel: + params = {"identifier": {"type": "m.id.user", "user": user}, "type": type} + params.update(extra_params) channel = self.make_request("POST", "/_matrix/client/r0/login", params) return channel - def _start_delete_device_session(self, access_token, device_id) -> str: + def _start_delete_device_session(self, access_token: str, device_id: str) -> str: """Make an initial delete device request, and return the UI Auth session ID""" channel = self._delete_device(access_token, device_id) self.assertEqual(channel.code, 401) |