diff options
author | Andrew Morgan <andrew@amorgan.xyz> | 2023-03-09 14:17:12 +0000 |
---|---|---|
committer | Andrew Morgan <andrew@amorgan.xyz> | 2023-03-09 16:50:31 +0000 |
commit | 46c0ab559b3713a99c9ac6453243c16ee66b1a3b (patch) | |
tree | 9f978e116f10f84dc21b9b51ecf7e57383a2a3ea | |
parent | Move callback-related code from the BackgroundUpdater to its own class (diff) | |
download | synapse-46c0ab559b3713a99c9ac6453243c16ee66b1a3b.tar.xz |
Move callback-related code from the PasswordAuthProvider to its own class
-rw-r--r-- | synapse/handlers/auth.py | 140 | ||||
-rw-r--r-- | synapse/module_api/__init__.py | 21 | ||||
-rw-r--r-- | synapse/module_api/callbacks/__init__.py | 2 | ||||
-rw-r--r-- | synapse/module_api/callbacks/password_auth_provider_callbacks.py | 138 | ||||
-rw-r--r-- | synapse/server.py | 2 | ||||
-rw-r--r-- | tests/handlers/test_password_providers.py | 10 |
6 files changed, 178 insertions, 135 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 308e38edea..f4a7d0f558 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -65,6 +65,10 @@ from synapse.http.server import finish_request, respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.module_api.callbacks.password_auth_provider_callbacks import ( + CHECK_3PID_AUTH_CALLBACK, + ON_LOGGED_OUT_CALLBACK, +) from synapse.storage.databases.main.registration import ( LoginTokenExpired, LoginTokenLookupResult, @@ -1096,7 +1100,7 @@ class AuthHandler: return self._password_enabled_for_login and self._password_localdb_enabled def get_supported_login_types(self) -> Iterable[str]: - """Get a the login types supported for the /login API + """Get the login types supported for the /login API By default this is just 'm.login.password' (unless password_enabled is False in the config file), but password auth providers can provide @@ -1999,124 +2003,16 @@ def load_single_legacy_password_auth_provider( ) -CHECK_3PID_AUTH_CALLBACK = Callable[ - [str, str, str], - Awaitable[ - Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]] - ], -] -ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable] -CHECK_AUTH_CALLBACK = Callable[ - [str, str, JsonDict], - Awaitable[ - Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]] - ], -] -GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[ - [JsonDict, JsonDict], - Awaitable[Optional[str]], -] -GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[ - [JsonDict, JsonDict], - Awaitable[Optional[str]], -] -IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]] - - class PasswordAuthProvider: """ A class that the AuthHandler calls when authenticating users It allows modules to provide alternative methods for authentication """ - def __init__(self) -> None: - # lists of callbacks - self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = [] - self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = [] - self.get_username_for_registration_callbacks: List[ - GET_USERNAME_FOR_REGISTRATION_CALLBACK - ] = [] - self.get_displayname_for_registration_callbacks: List[ - GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK - ] = [] - self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = [] - - # Mapping from login type to login parameters - self._supported_login_types: Dict[str, Tuple[str, ...]] = {} - - # Mapping from login type to auth checker callbacks - self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {} - - def register_password_auth_provider_callbacks( - self, - check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None, - on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None, - is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None, - auth_checkers: Optional[ - Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK] - ] = None, - get_username_for_registration: Optional[ - GET_USERNAME_FOR_REGISTRATION_CALLBACK - ] = None, - get_displayname_for_registration: Optional[ - GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK - ] = None, - ) -> None: - # Register check_3pid_auth callback - if check_3pid_auth is not None: - self.check_3pid_auth_callbacks.append(check_3pid_auth) - - # register on_logged_out callback - if on_logged_out is not None: - self.on_logged_out_callbacks.append(on_logged_out) - - if auth_checkers is not None: - # register a new supported login_type - # Iterate through all of the types being registered - for (login_type, fields), callback in auth_checkers.items(): - # Note: fields may be empty here. This would allow a modules auth checker to - # be called with just 'login_type' and no password or other secrets - - # Need to check that all the field names are strings or may get nasty errors later - for f in fields: - if not isinstance(f, str): - raise RuntimeError( - "A module tried to register support for login type: %s with parameters %s" - " but all parameter names must be strings" - % (login_type, fields) - ) - - # 2 modules supporting the same login type must expect the same fields - # e.g. 1 can't expect "pass" if the other expects "password" - # so throw an exception if that happens - if login_type not in self._supported_login_types.get(login_type, []): - self._supported_login_types[login_type] = fields - else: - fields_currently_supported = self._supported_login_types.get( - login_type - ) - if fields_currently_supported != fields: - raise RuntimeError( - "A module tried to register support for login type: %s with parameters %s" - " but another module had already registered support for that type with parameters %s" - % (login_type, fields, fields_currently_supported) - ) - - # Add the new method to the list of auth_checker_callbacks for this login type - self.auth_checker_callbacks.setdefault(login_type, []).append(callback) - - if get_username_for_registration is not None: - self.get_username_for_registration_callbacks.append( - get_username_for_registration, - ) - - if get_displayname_for_registration is not None: - self.get_displayname_for_registration_callbacks.append( - get_displayname_for_registration, - ) - - if is_3pid_allowed is not None: - self.is_3pid_allowed_callbacks.append(is_3pid_allowed) + def __init__(self, hs: "HomeServer") -> None: + self._module_api_callbacks = ( + hs.get_module_api_callbacks().password_auth_provider + ) def get_supported_login_types(self) -> Mapping[str, Iterable[str]]: """Get the login types supported by this password provider @@ -2126,7 +2022,7 @@ class PasswordAuthProvider: to the /login API. """ - return self._supported_login_types + return self._module_api_callbacks.supported_login_types async def check_auth( self, username: str, login_type: str, login_dict: JsonDict @@ -2149,7 +2045,7 @@ class PasswordAuthProvider: # Go through all callbacks for the login type until one returns with a value # other than None (i.e. until a callback returns a success) - for callback in self.auth_checker_callbacks[login_type]: + for callback in self._module_api_callbacks.auth_checker_callbacks[login_type]: try: result = await delay_cancellation( callback(username, login_type, login_dict) @@ -2214,7 +2110,7 @@ class PasswordAuthProvider: # (user_id, callback_func), where callback_func should be run # after we've finished everything else - for callback in self.check_3pid_auth_callbacks: + for callback in self._module_api_callbacks.check_3pid_auth_callbacks: try: result = await delay_cancellation(callback(medium, address, password)) except CancelledError: @@ -2272,7 +2168,7 @@ class PasswordAuthProvider: self, user_id: str, device_id: Optional[str], access_token: str ) -> None: # call all of the on_logged_out callbacks - for callback in self.on_logged_out_callbacks: + for callback in self._module_api_callbacks.on_logged_out_callbacks: try: await callback(user_id, device_id, access_token) except Exception as e: @@ -2297,7 +2193,9 @@ class PasswordAuthProvider: The localpart to use when registering this user, or None if no module returned a localpart. """ - for callback in self.get_username_for_registration_callbacks: + for ( + callback + ) in self._module_api_callbacks.get_username_for_registration_callbacks: try: res = await delay_cancellation(callback(uia_results, params)) @@ -2342,7 +2240,9 @@ class PasswordAuthProvider: A tuple which first element is the display name, and the second is an MXC URL to the user's avatar. """ - for callback in self.get_displayname_for_registration_callbacks: + for ( + callback + ) in self._module_api_callbacks.get_displayname_for_registration_callbacks: try: res = await delay_cancellation(callback(uia_results, params)) @@ -2385,7 +2285,7 @@ class PasswordAuthProvider: Returns: Whether the 3PID is allowed to be bound on this homeserver """ - for callback in self.is_3pid_allowed_callbacks: + for callback in self._module_api_callbacks.is_3pid_allowed_callbacks: try: res = await delay_cancellation(callback(medium, address, registration)) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index cb09423cb3..7d34f621f4 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -42,15 +42,7 @@ from synapse.events import EventBase from synapse.events.presence_router import PresenceRouter from synapse.events.spamcheck import SpamChecker from synapse.handlers.account_data import ON_ACCOUNT_DATA_UPDATED_CALLBACK -from synapse.handlers.auth import ( - CHECK_3PID_AUTH_CALLBACK, - CHECK_AUTH_CALLBACK, - GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK, - GET_USERNAME_FOR_REGISTRATION_CALLBACK, - IS_3PID_ALLOWED_CALLBACK, - ON_LOGGED_OUT_CALLBACK, - AuthHandler, -) +from synapse.handlers.auth import AuthHandler from synapse.handlers.device import DeviceHandler from synapse.handlers.push_rules import RuleSpec, check_actions from synapse.http.client import SimpleHttpClient @@ -79,6 +71,14 @@ from synapse.module_api.callbacks.background_updater_callbacks import ( MIN_BATCH_SIZE_CALLBACK, ON_UPDATE_CALLBACK, ) +from synapse.module_api.callbacks.password_auth_provider_callbacks import ( + CHECK_3PID_AUTH_CALLBACK, + CHECK_AUTH_CALLBACK, + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK, + GET_USERNAME_FOR_REGISTRATION_CALLBACK, + IS_3PID_ALLOWED_CALLBACK, + ON_LOGGED_OUT_CALLBACK, +) from synapse.module_api.callbacks.presence_router_callbacks import ( GET_INTERESTED_USERS_CALLBACK, GET_USERS_FOR_STATES_CALLBACK, @@ -271,7 +271,6 @@ class ModuleApi: self._public_room_list_manager = PublicRoomListManager(hs) self._account_data_manager = AccountDataManager(hs) - self._password_auth_provider = hs.get_password_auth_provider() self._account_data_handler = hs.get_account_data_handler() ################################################################################# @@ -417,7 +416,7 @@ class ModuleApi: Added in Synapse v1.46.0. """ - return self._password_auth_provider.register_password_auth_provider_callbacks( + return self._callbacks.password_auth_provider.register_callbacks( check_3pid_auth=check_3pid_auth, on_logged_out=on_logged_out, is_3pid_allowed=is_3pid_allowed, diff --git a/synapse/module_api/callbacks/__init__.py b/synapse/module_api/callbacks/__init__.py index a35e7bf511..add3f2938d 100644 --- a/synapse/module_api/callbacks/__init__.py +++ b/synapse/module_api/callbacks/__init__.py @@ -14,6 +14,7 @@ from .account_validity_callbacks import AccountValidityModuleApiCallbacks from .background_updater_callbacks import BackgroundUpdaterModuleApiCallbacks +from .password_auth_provider_callbacks import PasswordAuthProviderModuleApiCallbacks from .presence_router_callbacks import PresenceRouterModuleApiCallbacks from .spam_checker_callbacks import SpamCheckerModuleApiCallbacks from .third_party_event_rules_callbacks import ThirdPartyEventRulesModuleApiCallbacks @@ -27,6 +28,7 @@ class ModuleApiCallbacks: def __init__(self) -> None: self.account_validity = AccountValidityModuleApiCallbacks() self.background_updater = BackgroundUpdaterModuleApiCallbacks() + self.password_auth_provider = PasswordAuthProviderModuleApiCallbacks() self.presence_router = PresenceRouterModuleApiCallbacks() self.spam_checker = SpamCheckerModuleApiCallbacks() self.third_party_event_rules = ThirdPartyEventRulesModuleApiCallbacks() diff --git a/synapse/module_api/callbacks/password_auth_provider_callbacks.py b/synapse/module_api/callbacks/password_auth_provider_callbacks.py new file mode 100644 index 0000000000..fddaa24abd --- /dev/null +++ b/synapse/module_api/callbacks/password_auth_provider_callbacks.py @@ -0,0 +1,138 @@ +# Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2019 - 2020, 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple + +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.module_api import LoginResponse + +logger = logging.getLogger(__name__) + + +CHECK_3PID_AUTH_CALLBACK = Callable[ + [str, str, str], + Awaitable[ + Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]] + ], +] +ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable] +CHECK_AUTH_CALLBACK = Callable[ + [str, str, JsonDict], + Awaitable[ + Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]] + ], +] +GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[ + [JsonDict, JsonDict], + Awaitable[Optional[str]], +] +GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[ + [JsonDict, JsonDict], + Awaitable[Optional[str]], +] +IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]] + + +class PasswordAuthProviderModuleApiCallbacks: + def __init__(self) -> None: + # Mapping from login type to login parameters + self.supported_login_types: Dict[str, Tuple[str, ...]] = {} + + self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = [] + self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = [] + self.get_username_for_registration_callbacks: List[ + GET_USERNAME_FOR_REGISTRATION_CALLBACK + ] = [] + self.get_displayname_for_registration_callbacks: List[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + ] = [] + self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = [] + + # Mapping from login type to auth checker callbacks + self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {} + + def register_callbacks( + self, + check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None, + on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None, + is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None, + auth_checkers: Optional[ + Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK] + ] = None, + get_username_for_registration: Optional[ + GET_USERNAME_FOR_REGISTRATION_CALLBACK + ] = None, + get_displayname_for_registration: Optional[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + ] = None, + ) -> None: + # Register check_3pid_auth callback + if check_3pid_auth is not None: + self.check_3pid_auth_callbacks.append(check_3pid_auth) + + # register on_logged_out callback + if on_logged_out is not None: + self.on_logged_out_callbacks.append(on_logged_out) + + if auth_checkers is not None: + # register a new supported login_type + # Iterate through all of the types being registered + for (login_type, fields), callback in auth_checkers.items(): + # Note: fields may be empty here. This would allow a modules auth checker to + # be called with just 'login_type' and no password or other secrets + + # Need to check that all the field names are strings or may get nasty errors later + for f in fields: + if not isinstance(f, str): + raise RuntimeError( + "A module tried to register support for login type: %s with parameters %s" + " but all parameter names must be strings" + % (login_type, fields) + ) + + # 2 modules supporting the same login type must expect the same fields + # e.g. 1 can't expect "pass" if the other expects "password" + # so throw an exception if that happens + if login_type not in self.supported_login_types.get(login_type, []): + self.supported_login_types[login_type] = fields + else: + fields_currently_supported = self.supported_login_types.get( + login_type + ) + if fields_currently_supported != fields: + raise RuntimeError( + "A module tried to register support for login type: %s with parameters %s" + " but another module had already registered support for that type with parameters %s" + % (login_type, fields, fields_currently_supported) + ) + + # Add the new method to the list of auth_checker_callbacks for this login type + self.auth_checker_callbacks.setdefault(login_type, []).append(callback) + + if get_username_for_registration is not None: + self.get_username_for_registration_callbacks.append( + get_username_for_registration, + ) + + if get_displayname_for_registration is not None: + self.get_displayname_for_registration_callbacks.append( + get_displayname_for_registration, + ) + + if is_3pid_allowed is not None: + self.is_3pid_allowed_callbacks.append(is_3pid_allowed) diff --git a/synapse/server.py b/synapse/server.py index 9bd374ceae..729dab1f8d 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -674,7 +674,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_password_auth_provider(self) -> PasswordAuthProvider: - return PasswordAuthProvider() + return PasswordAuthProvider(self) @cache_in_self def get_room_member_handler(self) -> RoomMemberHandler: diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index aa91bc0a3d..dd55b0d285 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -727,7 +727,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.called = True on_logged_out = Mock(side_effect=on_logged_out) - self.hs.get_password_auth_provider().on_logged_out_callbacks.append( + self.hs.get_module_api_callbacks().password_auth_provider.on_logged_out_callbacks.append( on_logged_out ) @@ -857,7 +857,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): ) m = Mock(return_value=make_awaitable(False)) - self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m] + self.hs.get_module_api_callbacks().password_auth_provider.is_3pid_allowed_callbacks = [ + m + ] self.register_user(username, "password") tok = self.login(username, "password") @@ -887,7 +889,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): m.assert_called_once_with("email", "foo@test.com", registration) m = Mock(return_value=make_awaitable(True)) - self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m] + self.hs.get_module_api_callbacks().password_auth_provider.is_3pid_allowed_callbacks = [ + m + ] channel = self.make_request( "POST", |