diff options
author | Azrenbeth <77782548+Azrenbeth@users.noreply.github.com> | 2021-10-13 12:21:52 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-10-13 11:21:52 +0000 |
commit | cdd308845ba22fef22a39ed5bf904b438e48b491 (patch) | |
tree | bcd28fef3c151d3e7cc484e0e416c58c4ba3c679 /synapse/handlers | |
parent | Be more lenient when parsing the version for oEmbed responses. (#11065) (diff) | |
download | synapse-cdd308845ba22fef22a39ed5bf904b438e48b491.tar.xz |
Port the Password Auth Providers module interface to the new generic interface (#10548)
Co-authored-by: Azrenbeth <7782548+Azrenbeth@users.noreply.github.com> Co-authored-by: Brendan Abolivier <babolivier@matrix.org>
Diffstat (limited to 'synapse/handlers')
-rw-r--r-- | synapse/handlers/auth.py | 528 |
1 files changed, 388 insertions, 140 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index f4612a5b92..ebe75a9e9b 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -200,46 +200,13 @@ class AuthHandler: self.bcrypt_rounds = hs.config.registration.bcrypt_rounds - # we can't use hs.get_module_api() here, because to do so will create an - # import loop. - # - # TODO: refactor this class to separate the lower-level stuff that - # ModuleApi can use from the higher-level stuff that uses ModuleApi, as - # better way to break the loop - account_handler = ModuleApi(hs, self) - - self.password_providers = [ - PasswordProvider.load(module, config, account_handler) - for module, config in hs.config.authproviders.password_providers - ] - - logger.info("Extra password_providers: %s", self.password_providers) + self.password_auth_provider = hs.get_password_auth_provider() self.hs = hs # FIXME better possibility to access registrationHandler later? self.macaroon_gen = hs.get_macaroon_generator() self._password_enabled = hs.config.auth.password_enabled self._password_localdb_enabled = hs.config.auth.password_localdb_enabled - # start out by assuming PASSWORD is enabled; we will remove it later if not. - login_types = set() - if self._password_localdb_enabled: - login_types.add(LoginType.PASSWORD) - - for provider in self.password_providers: - login_types.update(provider.get_supported_login_types().keys()) - - if not self._password_enabled: - login_types.discard(LoginType.PASSWORD) - - # Some clients just pick the first type in the list. In this case, we want - # them to use PASSWORD (rather than token or whatever), so we want to make sure - # that comes first, where it's present. - self._supported_login_types = [] - if LoginType.PASSWORD in login_types: - self._supported_login_types.append(LoginType.PASSWORD) - login_types.remove(LoginType.PASSWORD) - self._supported_login_types.extend(login_types) - # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. self._failed_uia_attempts_ratelimiter = Ratelimiter( @@ -427,11 +394,10 @@ class AuthHandler: ui_auth_types.add(LoginType.PASSWORD) # also allow auth from password providers - for provider in self.password_providers: - for t in provider.get_supported_login_types().keys(): - if t == LoginType.PASSWORD and not self._password_enabled: - continue - ui_auth_types.add(t) + for t in self.password_auth_provider.get_supported_login_types().keys(): + if t == LoginType.PASSWORD and not self._password_enabled: + continue + ui_auth_types.add(t) # if sso is enabled, allow the user to log in via SSO iff they have a mapping # from sso to mxid. @@ -1038,7 +1004,25 @@ class AuthHandler: Returns: login types """ - return self._supported_login_types + # Load any login types registered by modules + # This is stored in the password_auth_provider so this doesn't trigger + # any callbacks + types = list(self.password_auth_provider.get_supported_login_types().keys()) + + # This list should include PASSWORD if (either _password_localdb_enabled is + # true or if one of the modules registered it) AND _password_enabled is true + # Also: + # Some clients just pick the first type in the list. In this case, we want + # them to use PASSWORD (rather than token or whatever), so we want to make sure + # that comes first, where it's present. + if LoginType.PASSWORD in types: + types.remove(LoginType.PASSWORD) + if self._password_enabled: + types.insert(0, LoginType.PASSWORD) + elif self._password_localdb_enabled and self._password_enabled: + types.insert(0, LoginType.PASSWORD) + + return types async def validate_login( self, @@ -1217,15 +1201,20 @@ class AuthHandler: known_login_type = False - for provider in self.password_providers: - supported_login_types = provider.get_supported_login_types() - if login_type not in supported_login_types: - # this password provider doesn't understand this login type - continue - + # Check if login_type matches a type registered by one of the modules + # We don't need to remove LoginType.PASSWORD from the list if password login is + # disabled, since if that were the case then by this point we know that the + # login_type is not LoginType.PASSWORD + supported_login_types = self.password_auth_provider.get_supported_login_types() + # check if the login type being used is supported by a module + if login_type in supported_login_types: + # Make a note that this login type is supported by the server known_login_type = True + # Get all the fields expected for this login types login_fields = supported_login_types[login_type] + # go through the login submission and keep track of which required fields are + # provided/not provided missing_fields = [] login_dict = {} for f in login_fields: @@ -1233,6 +1222,7 @@ class AuthHandler: missing_fields.append(f) else: login_dict[f] = login_submission[f] + # raise an error if any of the expected fields for that login type weren't provided if missing_fields: raise SynapseError( 400, @@ -1240,10 +1230,15 @@ class AuthHandler: % (login_type, missing_fields), ) - result = await provider.check_auth(username, login_type, login_dict) + # call all of the check_auth hooks for that login_type + # it will return a result once the first success is found (or None otherwise) + result = await self.password_auth_provider.check_auth( + username, login_type, login_dict + ) if result: return result + # if no module managed to authenticate the user, then fallback to built in password based auth if login_type == LoginType.PASSWORD and self._password_localdb_enabled: known_login_type = True @@ -1282,11 +1277,16 @@ class AuthHandler: completed login/registration, or `None`. If authentication was unsuccessful, `user_id` and `callback` are both `None`. """ - for provider in self.password_providers: - result = await provider.check_3pid_auth(medium, address, password) - if result: - return result + # call all of the check_3pid_auth callbacks + # Result will be from the first callback that returns something other than None + # If all the callbacks return None, then result is also set to None + result = await self.password_auth_provider.check_3pid_auth( + medium, address, password + ) + if result: + return result + # if result is None then return (None, None) return None, None async def _check_local_password(self, user_id: str, password: str) -> Optional[str]: @@ -1365,13 +1365,12 @@ class AuthHandler: user_info = await self.auth.get_user_by_access_token(access_token) await self.store.delete_access_token(access_token) - # see if any of our auth providers want to know about this - for provider in self.password_providers: - await provider.on_logged_out( - user_id=user_info.user_id, - device_id=user_info.device_id, - access_token=access_token, - ) + # see if any modules want to know about this + await self.password_auth_provider.on_logged_out( + user_id=user_info.user_id, + device_id=user_info.device_id, + access_token=access_token, + ) # delete pushers associated with this access token if user_info.token_id is not None: @@ -1398,12 +1397,11 @@ class AuthHandler: user_id, except_token_id=except_token_id, device_id=device_id ) - # see if any of our auth providers want to know about this - for provider in self.password_providers: - for token, _, device_id in tokens_and_devices: - await provider.on_logged_out( - user_id=user_id, device_id=device_id, access_token=token - ) + # see if any modules want to know about this + for token, _, device_id in tokens_and_devices: + await self.password_auth_provider.on_logged_out( + user_id=user_id, device_id=device_id, access_token=token + ) # delete pushers associated with the access tokens await self.hs.get_pusherpool().remove_pushers_by_access_token( @@ -1811,40 +1809,228 @@ class MacaroonGenerator: return macaroon -class PasswordProvider: - """Wrapper for a password auth provider module +def load_legacy_password_auth_providers(hs: "HomeServer") -> None: + module_api = hs.get_module_api() + for module, config in hs.config.authproviders.password_providers: + load_single_legacy_password_auth_provider( + module=module, config=config, api=module_api + ) - This class abstracts out all of the backwards-compatibility hacks for - password providers, to provide a consistent interface. - """ - @classmethod - def load( - cls, module: Type, config: JsonDict, module_api: ModuleApi - ) -> "PasswordProvider": - try: - pp = module(config=config, account_handler=module_api) - except Exception as e: - logger.error("Error while initializing %r: %s", module, e) - raise - return cls(pp, module_api) +def load_single_legacy_password_auth_provider( + module: Type, config: JsonDict, api: ModuleApi +) -> None: + try: + provider = module(config=config, account_handler=api) + except Exception as e: + logger.error("Error while initializing %r: %s", module, e) + raise + + # The known hooks. If a module implements a method who's name appears in this set + # we'll want to register it + password_auth_provider_methods = { + "check_3pid_auth", + "on_logged_out", + } + + # All methods that the module provides should be async, but this wasn't enforced + # in the old module system, so we wrap them if needed + def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]: + # f might be None if the callback isn't implemented by the module. In this + # case we don't want to register a callback at all so we return None. + if f is None: + return None + + # We need to wrap check_password because its old form would return a boolean + # but we now want it to behave just like check_auth() and return the matrix id of + # the user if authentication succeeded or None otherwise + if f.__name__ == "check_password": + + async def wrapped_check_password( + username: str, login_type: str, login_dict: JsonDict + ) -> Optional[Tuple[str, Optional[Callable]]]: + # We've already made sure f is not None above, but mypy doesn't do well + # across function boundaries so we need to tell it f is definitely not + # None. + assert f is not None + + matrix_user_id = api.get_qualified_user_id(username) + password = login_dict["password"] + + is_valid = await f(matrix_user_id, password) + + if is_valid: + return matrix_user_id, None + + return None - def __init__(self, pp: "PasswordProvider", module_api: ModuleApi): - self._pp = pp - self._module_api = module_api + return wrapped_check_password + + # We need to wrap check_auth as in the old form it could return + # just a str, but now it must return Optional[Tuple[str, Optional[Callable]] + if f.__name__ == "check_auth": + + async def wrapped_check_auth( + username: str, login_type: str, login_dict: JsonDict + ) -> Optional[Tuple[str, Optional[Callable]]]: + # We've already made sure f is not None above, but mypy doesn't do well + # across function boundaries so we need to tell it f is definitely not + # None. + assert f is not None + + result = await f(username, login_type, login_dict) + + if isinstance(result, str): + return result, None + + return result + + return wrapped_check_auth + + # We need to wrap check_3pid_auth as in the old form it could return + # just a str, but now it must return Optional[Tuple[str, Optional[Callable]] + if f.__name__ == "check_3pid_auth": + + async def wrapped_check_3pid_auth( + medium: str, address: str, password: str + ) -> Optional[Tuple[str, Optional[Callable]]]: + # We've already made sure f is not None above, but mypy doesn't do well + # across function boundaries so we need to tell it f is definitely not + # None. + assert f is not None + + result = await f(medium, address, password) + + if isinstance(result, str): + return result, None + + return result - self._supported_login_types = {} + return wrapped_check_3pid_auth - # grandfather in check_password support - if hasattr(self._pp, "check_password"): - self._supported_login_types[LoginType.PASSWORD] = ("password",) + def run(*args: Tuple, **kwargs: Dict) -> Awaitable: + # mypy doesn't do well across function boundaries so we need to tell it + # f is definitely not None. + assert f is not None - g = getattr(self._pp, "get_supported_login_types", None) - if g: - self._supported_login_types.update(g()) + return maybe_awaitable(f(*args, **kwargs)) - def __str__(self) -> str: - return str(self._pp) + return run + + # populate hooks with the implemented methods, wrapped with async_wrapper + hooks = { + hook: async_wrapper(getattr(provider, hook, None)) + for hook in password_auth_provider_methods + } + + supported_login_types = {} + # call get_supported_login_types and add that to the dict + g = getattr(provider, "get_supported_login_types", None) + if g is not None: + # Note the old module style also called get_supported_login_types at loading time + # and it is synchronous + supported_login_types.update(g()) + + auth_checkers = {} + # Legacy modules have a check_auth method which expects to be called with one of + # the keys returned by get_supported_login_types. New style modules register a + # dictionary of login_type->check_auth_method mappings + check_auth = async_wrapper(getattr(provider, "check_auth", None)) + if check_auth is not None: + for login_type, fields in supported_login_types.items(): + # need tuple(fields) since fields can be any Iterable type (so may not be hashable) + auth_checkers[(login_type, tuple(fields))] = check_auth + + # if it has a "check_password" method then it should handle all auth checks + # with login type of LoginType.PASSWORD + check_password = async_wrapper(getattr(provider, "check_password", None)) + if check_password is not None: + # need to use a tuple here for ("password",) not a list since lists aren't hashable + auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password + + api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers) + + +CHECK_3PID_AUTH_CALLBACK = Callable[ + [str, str, str], + Awaitable[ + Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]] + ], +] +ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable] +CHECK_AUTH_CALLBACK = Callable[ + [str, str, JsonDict], + Awaitable[ + Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]] + ], +] + + +class PasswordAuthProvider: + """ + A class that the AuthHandler calls when authenticating users + It allows modules to provide alternative methods for authentication + """ + + def __init__(self) -> None: + # lists of callbacks + self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = [] + self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = [] + + # Mapping from login type to login parameters + self._supported_login_types: Dict[str, Iterable[str]] = {} + + # Mapping from login type to auth checker callbacks + self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {} + + def register_password_auth_provider_callbacks( + self, + check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None, + on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None, + auth_checkers: Optional[Dict[Tuple[str, Tuple], CHECK_AUTH_CALLBACK]] = None, + ) -> None: + # Register check_3pid_auth callback + if check_3pid_auth is not None: + self.check_3pid_auth_callbacks.append(check_3pid_auth) + + # register on_logged_out callback + if on_logged_out is not None: + self.on_logged_out_callbacks.append(on_logged_out) + + if auth_checkers is not None: + # register a new supported login_type + # Iterate through all of the types being registered + for (login_type, fields), callback in auth_checkers.items(): + # Note: fields may be empty here. This would allow a modules auth checker to + # be called with just 'login_type' and no password or other secrets + + # Need to check that all the field names are strings or may get nasty errors later + for f in fields: + if not isinstance(f, str): + raise RuntimeError( + "A module tried to register support for login type: %s with parameters %s" + " but all parameter names must be strings" + % (login_type, fields) + ) + + # 2 modules supporting the same login type must expect the same fields + # e.g. 1 can't expect "pass" if the other expects "password" + # so throw an exception if that happens + if login_type not in self._supported_login_types.get(login_type, []): + self._supported_login_types[login_type] = fields + else: + fields_currently_supported = self._supported_login_types.get( + login_type + ) + if fields_currently_supported != fields: + raise RuntimeError( + "A module tried to register support for login type: %s with parameters %s" + " but another module had already registered support for that type with parameters %s" + % (login_type, fields, fields_currently_supported) + ) + + # Add the new method to the list of auth_checker_callbacks for this login type + self.auth_checker_callbacks.setdefault(login_type, []).append(callback) def get_supported_login_types(self) -> Mapping[str, Iterable[str]]: """Get the login types supported by this password provider @@ -1852,20 +2038,15 @@ class PasswordProvider: Returns a map from a login type identifier (such as m.login.password) to an iterable giving the fields which must be provided by the user in the submission to the /login API. - - This wrapper adds m.login.password to the list if the underlying password - provider supports the check_password() api. """ + return self._supported_login_types async def check_auth( self, username: str, login_type: str, login_dict: JsonDict - ) -> Optional[Tuple[str, Optional[Callable]]]: + ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]: """Check if the user has presented valid login credentials - This wrapper also calls check_password() if the underlying password provider - supports the check_password() api and the login type is m.login.password. - Args: username: user id presented by the client. Either an MXID or an unqualified username. @@ -1879,63 +2060,130 @@ class PasswordProvider: user, and `callback` is an optional callback which will be called with the result from the /login call (including access_token, device_id, etc.) """ - # first grandfather in a call to check_password - if login_type == LoginType.PASSWORD: - check_password = getattr(self._pp, "check_password", None) - if check_password: - qualified_user_id = self._module_api.get_qualified_user_id(username) - is_valid = await check_password( - qualified_user_id, login_dict["password"] - ) - if is_valid: - return qualified_user_id, None - check_auth = getattr(self._pp, "check_auth", None) - if not check_auth: - return None - result = await check_auth(username, login_type, login_dict) + # Go through all callbacks for the login type until one returns with a value + # other than None (i.e. until a callback returns a success) + for callback in self.auth_checker_callbacks[login_type]: + try: + result = await callback(username, login_type, login_dict) + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + continue - # Check if the return value is a str or a tuple - if isinstance(result, str): - # If it's a str, set callback function to None - return result, None + if result is not None: + # Check that the callback returned a Tuple[str, Optional[Callable]] + # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks + # result is always the right type, but as it is 3rd party code it might not be + + if not isinstance(result, tuple) or len(result) != 2: + logger.warning( + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue - return result + # pull out the two parts of the tuple so we can do type checking + str_result, callback_result = result + + # the 1st item in the tuple should be a str + if not isinstance(str_result, str): + logger.warning( # type: ignore[unreachable] + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue + + # the second should be Optional[Callable] + if callback_result is not None: + if not callable(callback_result): + logger.warning( # type: ignore[unreachable] + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue + + # The result is a (str, Optional[callback]) tuple so return the successful result + return result + + # If this point has been reached then none of the callbacks successfully authenticated + # the user so return None + return None async def check_3pid_auth( self, medium: str, address: str, password: str - ) -> Optional[Tuple[str, Optional[Callable]]]: - g = getattr(self._pp, "check_3pid_auth", None) - if not g: - return None - + ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]: # This function is able to return a deferred that either # resolves None, meaning authentication failure, or upon # success, to a str (which is the user_id) or a tuple of # (user_id, callback_func), where callback_func should be run # after we've finished everything else - result = await g(medium, address, password) - # Check if the return value is a str or a tuple - if isinstance(result, str): - # If it's a str, set callback function to None - return result, None + for callback in self.check_3pid_auth_callbacks: + try: + result = await callback(medium, address, password) + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + continue - return result + if result is not None: + # Check that the callback returned a Tuple[str, Optional[Callable]] + # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks + # result is always the right type, but as it is 3rd party code it might not be + + if not isinstance(result, tuple) or len(result) != 2: + logger.warning( + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue + + # pull out the two parts of the tuple so we can do type checking + str_result, callback_result = result + + # the 1st item in the tuple should be a str + if not isinstance(str_result, str): + logger.warning( # type: ignore[unreachable] + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue + + # the second should be Optional[Callable] + if callback_result is not None: + if not callable(callback_result): + logger.warning( # type: ignore[unreachable] + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue + + # The result is a (str, Optional[callback]) tuple so return the successful result + return result + + # If this point has been reached then none of the callbacks successfully authenticated + # the user so return None + return None async def on_logged_out( self, user_id: str, device_id: Optional[str], access_token: str ) -> None: - g = getattr(self._pp, "on_logged_out", None) - if not g: - return - # This might return an awaitable, if it does block the log out - # until it completes. - await maybe_awaitable( - g( - user_id=user_id, - device_id=device_id, - access_token=access_token, - ) - ) + # call all of the on_logged_out callbacks + for callback in self.on_logged_out_callbacks: + try: + callback(user_id, device_id, access_token) + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + continue |