summary refs log tree commit diff
path: root/tests/handlers/test_password_providers.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_password_providers.py')
-rw-r--r--tests/handlers/test_password_providers.py144
1 files changed, 74 insertions, 70 deletions
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 75934b1707..0916de64f5 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -15,12 +15,13 @@
 """Tests for the password_auth_provider interface"""
 
 from http import HTTPStatus
-from typing import Any, Type, Union
+from typing import Any, Dict, List, Optional, Type, Union
 from unittest.mock import Mock
 
 import synapse
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes
+from synapse.handlers.account import AccountHandler
 from synapse.module_api import ModuleApi
 from synapse.rest.client import account, devices, login, logout, register
 from synapse.types import JsonDict, UserID
@@ -44,13 +45,13 @@ class LegacyPasswordOnlyAuthProvider:
     """A legacy password_provider which only implements `check_password`."""
 
     @staticmethod
-    def parse_config(self):
+    def parse_config(config: JsonDict) -> None:
         pass
 
-    def __init__(self, config, account_handler):
+    def __init__(self, config: None, account_handler: AccountHandler):
         pass
 
-    def check_password(self, *args):
+    def check_password(self, *args: str) -> Mock:
         return mock_password_provider.check_password(*args)
 
 
@@ -58,16 +59,16 @@ class LegacyCustomAuthProvider:
     """A legacy password_provider which implements a custom login type."""
 
     @staticmethod
-    def parse_config(self):
+    def parse_config(config: JsonDict) -> None:
         pass
 
-    def __init__(self, config, account_handler):
+    def __init__(self, config: None, account_handler: AccountHandler):
         pass
 
-    def get_supported_login_types(self):
+    def get_supported_login_types(self) -> Dict[str, List[str]]:
         return {"test.login_type": ["test_field"]}
 
-    def check_auth(self, *args):
+    def check_auth(self, *args: str) -> Mock:
         return mock_password_provider.check_auth(*args)
 
 
@@ -75,15 +76,15 @@ class CustomAuthProvider:
     """A module which registers password_auth_provider callbacks for a custom login type."""
 
     @staticmethod
-    def parse_config(self):
+    def parse_config(config: JsonDict) -> None:
         pass
 
-    def __init__(self, config, api: ModuleApi):
+    def __init__(self, config: None, api: ModuleApi):
         api.register_password_auth_provider_callbacks(
             auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
         )
 
-    def check_auth(self, *args):
+    def check_auth(self, *args: Any) -> Mock:
         return mock_password_provider.check_auth(*args)
 
 
@@ -92,16 +93,16 @@ class LegacyPasswordCustomAuthProvider:
     as a custom type."""
 
     @staticmethod
-    def parse_config(self):
+    def parse_config(config: JsonDict) -> None:
         pass
 
-    def __init__(self, config, account_handler):
+    def __init__(self, config: None, account_handler: AccountHandler):
         pass
 
-    def get_supported_login_types(self):
+    def get_supported_login_types(self) -> Dict[str, List[str]]:
         return {"m.login.password": ["password"], "test.login_type": ["test_field"]}
 
-    def check_auth(self, *args):
+    def check_auth(self, *args: str) -> Mock:
         return mock_password_provider.check_auth(*args)
 
 
@@ -110,10 +111,10 @@ class PasswordCustomAuthProvider:
     as well as a password login"""
 
     @staticmethod
-    def parse_config(self):
+    def parse_config(config: JsonDict) -> None:
         pass
 
-    def __init__(self, config, api: ModuleApi):
+    def __init__(self, config: None, api: ModuleApi):
         api.register_password_auth_provider_callbacks(
             auth_checkers={
                 ("test.login_type", ("test_field",)): self.check_auth,
@@ -121,10 +122,10 @@ class PasswordCustomAuthProvider:
             }
         )
 
-    def check_auth(self, *args):
+    def check_auth(self, *args: Any) -> Mock:
         return mock_password_provider.check_auth(*args)
 
-    def check_pass(self, *args):
+    def check_pass(self, *args: str) -> Mock:
         return mock_password_provider.check_password(*args)
 
 
@@ -161,16 +162,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
     CALLBACK_USERNAME = "get_username_for_registration"
     CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
 
-    def setUp(self):
+    def setUp(self) -> None:
         # we use a global mock device, so make sure we are starting with a clean slate
         mock_password_provider.reset_mock()
         super().setUp()
 
     @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
-    def test_password_only_auth_progiver_login_legacy(self):
+    def test_password_only_auth_progiver_login_legacy(self) -> None:
         self.password_only_auth_provider_login_test_body()
 
-    def password_only_auth_provider_login_test_body(self):
+    def password_only_auth_provider_login_test_body(self) -> None:
         # login flows should only have m.login.password
         flows = self._get_login_flows()
         self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
@@ -201,10 +202,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         )
 
     @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
-    def test_password_only_auth_provider_ui_auth_legacy(self):
+    def test_password_only_auth_provider_ui_auth_legacy(self) -> None:
         self.password_only_auth_provider_ui_auth_test_body()
 
-    def password_only_auth_provider_ui_auth_test_body(self):
+    def password_only_auth_provider_ui_auth_test_body(self) -> None:
         """UI Auth should delegate correctly to the password provider"""
 
         # create the user, otherwise access doesn't work
@@ -238,10 +239,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
 
     @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
-    def test_local_user_fallback_login_legacy(self):
+    def test_local_user_fallback_login_legacy(self) -> None:
         self.local_user_fallback_login_test_body()
 
-    def local_user_fallback_login_test_body(self):
+    def local_user_fallback_login_test_body(self) -> None:
         """rejected login should fall back to local db"""
         self.register_user("localuser", "localpass")
 
@@ -255,10 +256,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual("@localuser:test", channel.json_body["user_id"])
 
     @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
-    def test_local_user_fallback_ui_auth_legacy(self):
+    def test_local_user_fallback_ui_auth_legacy(self) -> None:
         self.local_user_fallback_ui_auth_test_body()
 
-    def local_user_fallback_ui_auth_test_body(self):
+    def local_user_fallback_ui_auth_test_body(self) -> None:
         """rejected login should fall back to local db"""
         self.register_user("localuser", "localpass")
 
@@ -298,10 +299,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "password_config": {"localdb_enabled": False},
         }
     )
-    def test_no_local_user_fallback_login_legacy(self):
+    def test_no_local_user_fallback_login_legacy(self) -> None:
         self.no_local_user_fallback_login_test_body()
 
-    def no_local_user_fallback_login_test_body(self):
+    def no_local_user_fallback_login_test_body(self) -> None:
         """localdb_enabled can block login with the local password"""
         self.register_user("localuser", "localpass")
 
@@ -320,10 +321,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "password_config": {"localdb_enabled": False},
         }
     )
-    def test_no_local_user_fallback_ui_auth_legacy(self):
+    def test_no_local_user_fallback_ui_auth_legacy(self) -> None:
         self.no_local_user_fallback_ui_auth_test_body()
 
-    def no_local_user_fallback_ui_auth_test_body(self):
+    def no_local_user_fallback_ui_auth_test_body(self) -> None:
         """localdb_enabled can block ui auth with the local password"""
         self.register_user("localuser", "localpass")
 
@@ -361,10 +362,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "password_config": {"enabled": False},
         }
     )
-    def test_password_auth_disabled_legacy(self):
+    def test_password_auth_disabled_legacy(self) -> None:
         self.password_auth_disabled_test_body()
 
-    def password_auth_disabled_test_body(self):
+    def password_auth_disabled_test_body(self) -> None:
         """password auth doesn't work if it's disabled across the board"""
         # login flows should be empty
         flows = self._get_login_flows()
@@ -376,14 +377,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.check_password.assert_not_called()
 
     @override_config(legacy_providers_config(LegacyCustomAuthProvider))
-    def test_custom_auth_provider_login_legacy(self):
+    def test_custom_auth_provider_login_legacy(self) -> None:
         self.custom_auth_provider_login_test_body()
 
     @override_config(providers_config(CustomAuthProvider))
-    def test_custom_auth_provider_login(self):
+    def test_custom_auth_provider_login(self) -> None:
         self.custom_auth_provider_login_test_body()
 
-    def custom_auth_provider_login_test_body(self):
+    def custom_auth_provider_login_test_body(self) -> None:
         # login flows should have the custom flow and m.login.password, since we
         # haven't disabled local password lookup.
         # (password must come first, because reasons)
@@ -424,14 +425,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         )
 
     @override_config(legacy_providers_config(LegacyCustomAuthProvider))
-    def test_custom_auth_provider_ui_auth_legacy(self):
+    def test_custom_auth_provider_ui_auth_legacy(self) -> None:
         self.custom_auth_provider_ui_auth_test_body()
 
     @override_config(providers_config(CustomAuthProvider))
-    def test_custom_auth_provider_ui_auth(self):
+    def test_custom_auth_provider_ui_auth(self) -> None:
         self.custom_auth_provider_ui_auth_test_body()
 
-    def custom_auth_provider_ui_auth_test_body(self):
+    def custom_auth_provider_ui_auth_test_body(self) -> None:
         # register the user and log in twice, to get two devices
         self.register_user("localuser", "localpass")
         tok1 = self.login("localuser", "localpass")
@@ -486,14 +487,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         )
 
     @override_config(legacy_providers_config(LegacyCustomAuthProvider))
-    def test_custom_auth_provider_callback_legacy(self):
+    def test_custom_auth_provider_callback_legacy(self) -> None:
         self.custom_auth_provider_callback_test_body()
 
     @override_config(providers_config(CustomAuthProvider))
-    def test_custom_auth_provider_callback(self):
+    def test_custom_auth_provider_callback(self) -> None:
         self.custom_auth_provider_callback_test_body()
 
-    def custom_auth_provider_callback_test_body(self):
+    def custom_auth_provider_callback_test_body(self) -> None:
         callback = Mock(return_value=make_awaitable(None))
 
         mock_password_provider.check_auth.return_value = make_awaitable(
@@ -521,16 +522,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "password_config": {"enabled": False},
         }
     )
-    def test_custom_auth_password_disabled_legacy(self):
+    def test_custom_auth_password_disabled_legacy(self) -> None:
         self.custom_auth_password_disabled_test_body()
 
     @override_config(
         {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
     )
-    def test_custom_auth_password_disabled(self):
+    def test_custom_auth_password_disabled(self) -> None:
         self.custom_auth_password_disabled_test_body()
 
-    def custom_auth_password_disabled_test_body(self):
+    def custom_auth_password_disabled_test_body(self) -> None:
         """Test login with a custom auth provider where password login is disabled"""
         self.register_user("localuser", "localpass")
 
@@ -548,7 +549,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "password_config": {"enabled": False, "localdb_enabled": False},
         }
     )
-    def test_custom_auth_password_disabled_localdb_enabled_legacy(self):
+    def test_custom_auth_password_disabled_localdb_enabled_legacy(self) -> None:
         self.custom_auth_password_disabled_localdb_enabled_test_body()
 
     @override_config(
@@ -557,10 +558,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "password_config": {"enabled": False, "localdb_enabled": False},
         }
     )
-    def test_custom_auth_password_disabled_localdb_enabled(self):
+    def test_custom_auth_password_disabled_localdb_enabled(self) -> None:
         self.custom_auth_password_disabled_localdb_enabled_test_body()
 
-    def custom_auth_password_disabled_localdb_enabled_test_body(self):
+    def custom_auth_password_disabled_localdb_enabled_test_body(self) -> None:
         """Check the localdb_enabled == enabled == False
 
         Regression test for https://github.com/matrix-org/synapse/issues/8914: check
@@ -583,7 +584,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "password_config": {"enabled": False},
         }
     )
-    def test_password_custom_auth_password_disabled_login_legacy(self):
+    def test_password_custom_auth_password_disabled_login_legacy(self) -> None:
         self.password_custom_auth_password_disabled_login_test_body()
 
     @override_config(
@@ -592,10 +593,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "password_config": {"enabled": False},
         }
     )
-    def test_password_custom_auth_password_disabled_login(self):
+    def test_password_custom_auth_password_disabled_login(self) -> None:
         self.password_custom_auth_password_disabled_login_test_body()
 
-    def password_custom_auth_password_disabled_login_test_body(self):
+    def password_custom_auth_password_disabled_login_test_body(self) -> None:
         """log in with a custom auth provider which implements password, but password
         login is disabled"""
         self.register_user("localuser", "localpass")
@@ -615,7 +616,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "password_config": {"enabled": False},
         }
     )
-    def test_password_custom_auth_password_disabled_ui_auth_legacy(self):
+    def test_password_custom_auth_password_disabled_ui_auth_legacy(self) -> None:
         self.password_custom_auth_password_disabled_ui_auth_test_body()
 
     @override_config(
@@ -624,10 +625,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "password_config": {"enabled": False},
         }
     )
-    def test_password_custom_auth_password_disabled_ui_auth(self):
+    def test_password_custom_auth_password_disabled_ui_auth(self) -> None:
         self.password_custom_auth_password_disabled_ui_auth_test_body()
 
-    def password_custom_auth_password_disabled_ui_auth_test_body(self):
+    def password_custom_auth_password_disabled_ui_auth_test_body(self) -> None:
         """UI Auth with a custom auth provider which implements password, but password
         login is disabled"""
         # register the user and log in twice via the test login type to get two devices,
@@ -689,7 +690,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "password_config": {"localdb_enabled": False},
         }
     )
-    def test_custom_auth_no_local_user_fallback_legacy(self):
+    def test_custom_auth_no_local_user_fallback_legacy(self) -> None:
         self.custom_auth_no_local_user_fallback_test_body()
 
     @override_config(
@@ -698,10 +699,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "password_config": {"localdb_enabled": False},
         }
     )
-    def test_custom_auth_no_local_user_fallback(self):
+    def test_custom_auth_no_local_user_fallback(self) -> None:
         self.custom_auth_no_local_user_fallback_test_body()
 
-    def custom_auth_no_local_user_fallback_test_body(self):
+    def custom_auth_no_local_user_fallback_test_body(self) -> None:
         """Test login with a custom auth provider where the local db is disabled"""
         self.register_user("localuser", "localpass")
 
@@ -713,14 +714,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         channel = self._send_password_login("localuser", "localpass")
         self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
 
-    def test_on_logged_out(self):
+    def test_on_logged_out(self) -> None:
         """Tests that the on_logged_out callback is called when the user logs out."""
         self.register_user("rin", "password")
         tok = self.login("rin", "password")
 
         self.called = False
 
-        async def on_logged_out(user_id, device_id, access_token):
+        async def on_logged_out(
+            user_id: str, device_id: Optional[str], access_token: str
+        ) -> None:
             self.called = True
 
         on_logged_out = Mock(side_effect=on_logged_out)
@@ -738,7 +741,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         on_logged_out.assert_called_once()
         self.assertTrue(self.called)
 
-    def test_username(self):
+    def test_username(self) -> None:
         """Tests that the get_username_for_registration callback can define the username
         of a user when registering.
         """
@@ -763,7 +766,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mxid = channel.json_body["user_id"]
         self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
 
-    def test_username_uia(self):
+    def test_username_uia(self) -> None:
         """Tests that the get_username_for_registration callback is only called at the
         end of the UIA flow.
         """
@@ -782,7 +785,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
     # Set some email configuration so the test doesn't fail because of its absence.
     @override_config({"email": {"notif_from": "noreply@test"}})
-    def test_3pid_allowed(self):
+    def test_3pid_allowed(self) -> None:
         """Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse
         to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind
         the 3PID. Also checks that the module is passed a boolean indicating whether the
@@ -791,7 +794,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self._test_3pid_allowed("rin", False)
         self._test_3pid_allowed("kitay", True)
 
-    def test_displayname(self):
+    def test_displayname(self) -> None:
         """Tests that the get_displayname_for_registration callback can define the
         display name of a user when registering.
         """
@@ -820,7 +823,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
         self.assertEqual(display_name, username + "-foo")
 
-    def test_displayname_uia(self):
+    def test_displayname_uia(self) -> None:
         """Tests that the get_displayname_for_registration callback is only called at the
         end of the UIA flow.
         """
@@ -841,7 +844,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         # Check that the callback has been called.
         m.assert_called_once()
 
-    def _test_3pid_allowed(self, username: str, registration: bool):
+    def _test_3pid_allowed(self, username: str, registration: bool) -> None:
         """Tests that the "is_3pid_allowed" module callback is called correctly, using
         either /register or /account URLs depending on the arguments.
 
@@ -907,7 +910,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         client is trying to register.
         """
 
-        async def callback(uia_results, params):
+        async def callback(uia_results: JsonDict, params: JsonDict) -> str:
             self.assertIn(LoginType.DUMMY, uia_results)
             username = params["username"]
             return username + "-foo"
@@ -950,12 +953,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
     def _send_password_login(self, user: str, password: str) -> FakeChannel:
         return self._send_login(type="m.login.password", user=user, password=password)
 
-    def _send_login(self, type, user, **params) -> FakeChannel:
-        params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
+    def _send_login(self, type: str, user: str, **extra_params: str) -> FakeChannel:
+        params = {"identifier": {"type": "m.id.user", "user": user}, "type": type}
+        params.update(extra_params)
         channel = self.make_request("POST", "/_matrix/client/r0/login", params)
         return channel
 
-    def _start_delete_device_session(self, access_token, device_id) -> str:
+    def _start_delete_device_session(self, access_token: str, device_id: str) -> str:
         """Make an initial delete device request, and return the UI Auth session ID"""
         channel = self._delete_device(access_token, device_id)
         self.assertEqual(channel.code, 401)