summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py528
1 files changed, 388 insertions, 140 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index f4612a5b92..ebe75a9e9b 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -200,46 +200,13 @@ class AuthHandler:
 
         self.bcrypt_rounds = hs.config.registration.bcrypt_rounds
 
-        # we can't use hs.get_module_api() here, because to do so will create an
-        # import loop.
-        #
-        # TODO: refactor this class to separate the lower-level stuff that
-        #   ModuleApi can use from the higher-level stuff that uses ModuleApi, as
-        #   better way to break the loop
-        account_handler = ModuleApi(hs, self)
-
-        self.password_providers = [
-            PasswordProvider.load(module, config, account_handler)
-            for module, config in hs.config.authproviders.password_providers
-        ]
-
-        logger.info("Extra password_providers: %s", self.password_providers)
+        self.password_auth_provider = hs.get_password_auth_provider()
 
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.macaroon_gen = hs.get_macaroon_generator()
         self._password_enabled = hs.config.auth.password_enabled
         self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
 
-        # start out by assuming PASSWORD is enabled; we will remove it later if not.
-        login_types = set()
-        if self._password_localdb_enabled:
-            login_types.add(LoginType.PASSWORD)
-
-        for provider in self.password_providers:
-            login_types.update(provider.get_supported_login_types().keys())
-
-        if not self._password_enabled:
-            login_types.discard(LoginType.PASSWORD)
-
-        # Some clients just pick the first type in the list. In this case, we want
-        # them to use PASSWORD (rather than token or whatever), so we want to make sure
-        # that comes first, where it's present.
-        self._supported_login_types = []
-        if LoginType.PASSWORD in login_types:
-            self._supported_login_types.append(LoginType.PASSWORD)
-            login_types.remove(LoginType.PASSWORD)
-        self._supported_login_types.extend(login_types)
-
         # Ratelimiter for failed auth during UIA. Uses same ratelimit config
         # as per `rc_login.failed_attempts`.
         self._failed_uia_attempts_ratelimiter = Ratelimiter(
@@ -427,11 +394,10 @@ class AuthHandler:
                     ui_auth_types.add(LoginType.PASSWORD)
 
         # also allow auth from password providers
-        for provider in self.password_providers:
-            for t in provider.get_supported_login_types().keys():
-                if t == LoginType.PASSWORD and not self._password_enabled:
-                    continue
-                ui_auth_types.add(t)
+        for t in self.password_auth_provider.get_supported_login_types().keys():
+            if t == LoginType.PASSWORD and not self._password_enabled:
+                continue
+            ui_auth_types.add(t)
 
         # if sso is enabled, allow the user to log in via SSO iff they have a mapping
         # from sso to mxid.
@@ -1038,7 +1004,25 @@ class AuthHandler:
         Returns:
             login types
         """
-        return self._supported_login_types
+        # Load any login types registered by modules
+        # This is stored in the password_auth_provider so this doesn't trigger
+        # any callbacks
+        types = list(self.password_auth_provider.get_supported_login_types().keys())
+
+        # This list should include PASSWORD if (either _password_localdb_enabled is
+        # true or if one of the modules registered it) AND _password_enabled is true
+        # Also:
+        # Some clients just pick the first type in the list. In this case, we want
+        # them to use PASSWORD (rather than token or whatever), so we want to make sure
+        # that comes first, where it's present.
+        if LoginType.PASSWORD in types:
+            types.remove(LoginType.PASSWORD)
+            if self._password_enabled:
+                types.insert(0, LoginType.PASSWORD)
+        elif self._password_localdb_enabled and self._password_enabled:
+            types.insert(0, LoginType.PASSWORD)
+
+        return types
 
     async def validate_login(
         self,
@@ -1217,15 +1201,20 @@ class AuthHandler:
 
         known_login_type = False
 
-        for provider in self.password_providers:
-            supported_login_types = provider.get_supported_login_types()
-            if login_type not in supported_login_types:
-                # this password provider doesn't understand this login type
-                continue
-
+        # Check if login_type matches a type registered by one of the modules
+        # We don't need to remove LoginType.PASSWORD from the list if password login is
+        # disabled, since if that were the case then by this point we know that the
+        # login_type is not LoginType.PASSWORD
+        supported_login_types = self.password_auth_provider.get_supported_login_types()
+        # check if the login type being used is supported by a module
+        if login_type in supported_login_types:
+            # Make a note that this login type is supported by the server
             known_login_type = True
+            # Get all the fields expected for this login types
             login_fields = supported_login_types[login_type]
 
+            # go through the login submission and keep track of which required fields are
+            # provided/not provided
             missing_fields = []
             login_dict = {}
             for f in login_fields:
@@ -1233,6 +1222,7 @@ class AuthHandler:
                     missing_fields.append(f)
                 else:
                     login_dict[f] = login_submission[f]
+            # raise an error if any of the expected fields for that login type weren't provided
             if missing_fields:
                 raise SynapseError(
                     400,
@@ -1240,10 +1230,15 @@ class AuthHandler:
                     % (login_type, missing_fields),
                 )
 
-            result = await provider.check_auth(username, login_type, login_dict)
+            # call all of the check_auth hooks for that login_type
+            # it will return a result once the first success is found (or None otherwise)
+            result = await self.password_auth_provider.check_auth(
+                username, login_type, login_dict
+            )
             if result:
                 return result
 
+        # if no module managed to authenticate the user, then fallback to built in password based auth
         if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
             known_login_type = True
 
@@ -1282,11 +1277,16 @@ class AuthHandler:
             completed login/registration, or `None`. If authentication was
             unsuccessful, `user_id` and `callback` are both `None`.
         """
-        for provider in self.password_providers:
-            result = await provider.check_3pid_auth(medium, address, password)
-            if result:
-                return result
+        # call all of the check_3pid_auth callbacks
+        # Result will be from the first callback that returns something other than None
+        # If all the callbacks return None, then result is also set to None
+        result = await self.password_auth_provider.check_3pid_auth(
+            medium, address, password
+        )
+        if result:
+            return result
 
+        # if result is None then return (None, None)
         return None, None
 
     async def _check_local_password(self, user_id: str, password: str) -> Optional[str]:
@@ -1365,13 +1365,12 @@ class AuthHandler:
         user_info = await self.auth.get_user_by_access_token(access_token)
         await self.store.delete_access_token(access_token)
 
-        # see if any of our auth providers want to know about this
-        for provider in self.password_providers:
-            await provider.on_logged_out(
-                user_id=user_info.user_id,
-                device_id=user_info.device_id,
-                access_token=access_token,
-            )
+        # see if any modules want to know about this
+        await self.password_auth_provider.on_logged_out(
+            user_id=user_info.user_id,
+            device_id=user_info.device_id,
+            access_token=access_token,
+        )
 
         # delete pushers associated with this access token
         if user_info.token_id is not None:
@@ -1398,12 +1397,11 @@ class AuthHandler:
             user_id, except_token_id=except_token_id, device_id=device_id
         )
 
-        # see if any of our auth providers want to know about this
-        for provider in self.password_providers:
-            for token, _, device_id in tokens_and_devices:
-                await provider.on_logged_out(
-                    user_id=user_id, device_id=device_id, access_token=token
-                )
+        # see if any modules want to know about this
+        for token, _, device_id in tokens_and_devices:
+            await self.password_auth_provider.on_logged_out(
+                user_id=user_id, device_id=device_id, access_token=token
+            )
 
         # delete pushers associated with the access tokens
         await self.hs.get_pusherpool().remove_pushers_by_access_token(
@@ -1811,40 +1809,228 @@ class MacaroonGenerator:
         return macaroon
 
 
-class PasswordProvider:
-    """Wrapper for a password auth provider module
+def load_legacy_password_auth_providers(hs: "HomeServer") -> None:
+    module_api = hs.get_module_api()
+    for module, config in hs.config.authproviders.password_providers:
+        load_single_legacy_password_auth_provider(
+            module=module, config=config, api=module_api
+        )
 
-    This class abstracts out all of the backwards-compatibility hacks for
-    password providers, to provide a consistent interface.
-    """
 
-    @classmethod
-    def load(
-        cls, module: Type, config: JsonDict, module_api: ModuleApi
-    ) -> "PasswordProvider":
-        try:
-            pp = module(config=config, account_handler=module_api)
-        except Exception as e:
-            logger.error("Error while initializing %r: %s", module, e)
-            raise
-        return cls(pp, module_api)
+def load_single_legacy_password_auth_provider(
+    module: Type, config: JsonDict, api: ModuleApi
+) -> None:
+    try:
+        provider = module(config=config, account_handler=api)
+    except Exception as e:
+        logger.error("Error while initializing %r: %s", module, e)
+        raise
+
+    # The known hooks. If a module implements a method who's name appears in this set
+    # we'll want to register it
+    password_auth_provider_methods = {
+        "check_3pid_auth",
+        "on_logged_out",
+    }
+
+    # All methods that the module provides should be async, but this wasn't enforced
+    # in the old module system, so we wrap them if needed
+    def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
+        # f might be None if the callback isn't implemented by the module. In this
+        # case we don't want to register a callback at all so we return None.
+        if f is None:
+            return None
+
+        # We need to wrap check_password because its old form would return a boolean
+        # but we now want it to behave just like check_auth() and return the matrix id of
+        # the user if authentication succeeded or None otherwise
+        if f.__name__ == "check_password":
+
+            async def wrapped_check_password(
+                username: str, login_type: str, login_dict: JsonDict
+            ) -> Optional[Tuple[str, Optional[Callable]]]:
+                # We've already made sure f is not None above, but mypy doesn't do well
+                # across function boundaries so we need to tell it f is definitely not
+                # None.
+                assert f is not None
+
+                matrix_user_id = api.get_qualified_user_id(username)
+                password = login_dict["password"]
+
+                is_valid = await f(matrix_user_id, password)
+
+                if is_valid:
+                    return matrix_user_id, None
+
+                return None
 
-    def __init__(self, pp: "PasswordProvider", module_api: ModuleApi):
-        self._pp = pp
-        self._module_api = module_api
+            return wrapped_check_password
+
+        # We need to wrap check_auth as in the old form it could return
+        # just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
+        if f.__name__ == "check_auth":
+
+            async def wrapped_check_auth(
+                username: str, login_type: str, login_dict: JsonDict
+            ) -> Optional[Tuple[str, Optional[Callable]]]:
+                # We've already made sure f is not None above, but mypy doesn't do well
+                # across function boundaries so we need to tell it f is definitely not
+                # None.
+                assert f is not None
+
+                result = await f(username, login_type, login_dict)
+
+                if isinstance(result, str):
+                    return result, None
+
+                return result
+
+            return wrapped_check_auth
+
+        # We need to wrap check_3pid_auth as in the old form it could return
+        # just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
+        if f.__name__ == "check_3pid_auth":
+
+            async def wrapped_check_3pid_auth(
+                medium: str, address: str, password: str
+            ) -> Optional[Tuple[str, Optional[Callable]]]:
+                # We've already made sure f is not None above, but mypy doesn't do well
+                # across function boundaries so we need to tell it f is definitely not
+                # None.
+                assert f is not None
+
+                result = await f(medium, address, password)
+
+                if isinstance(result, str):
+                    return result, None
+
+                return result
 
-        self._supported_login_types = {}
+            return wrapped_check_3pid_auth
 
-        # grandfather in check_password support
-        if hasattr(self._pp, "check_password"):
-            self._supported_login_types[LoginType.PASSWORD] = ("password",)
+        def run(*args: Tuple, **kwargs: Dict) -> Awaitable:
+            # mypy doesn't do well across function boundaries so we need to tell it
+            # f is definitely not None.
+            assert f is not None
 
-        g = getattr(self._pp, "get_supported_login_types", None)
-        if g:
-            self._supported_login_types.update(g())
+            return maybe_awaitable(f(*args, **kwargs))
 
-    def __str__(self) -> str:
-        return str(self._pp)
+        return run
+
+    # populate hooks with the implemented methods, wrapped with async_wrapper
+    hooks = {
+        hook: async_wrapper(getattr(provider, hook, None))
+        for hook in password_auth_provider_methods
+    }
+
+    supported_login_types = {}
+    # call get_supported_login_types and add that to the dict
+    g = getattr(provider, "get_supported_login_types", None)
+    if g is not None:
+        # Note the old module style also called get_supported_login_types at loading time
+        # and it is synchronous
+        supported_login_types.update(g())
+
+    auth_checkers = {}
+    # Legacy modules have a check_auth method which expects to be called with one of
+    # the keys returned by get_supported_login_types. New style modules register a
+    # dictionary of login_type->check_auth_method mappings
+    check_auth = async_wrapper(getattr(provider, "check_auth", None))
+    if check_auth is not None:
+        for login_type, fields in supported_login_types.items():
+            # need tuple(fields) since fields can be any Iterable type (so may not be hashable)
+            auth_checkers[(login_type, tuple(fields))] = check_auth
+
+    # if it has a "check_password" method then it should handle all auth checks
+    # with login type of LoginType.PASSWORD
+    check_password = async_wrapper(getattr(provider, "check_password", None))
+    if check_password is not None:
+        # need to use a tuple here for ("password",) not a list since lists aren't hashable
+        auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password
+
+    api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers)
+
+
+CHECK_3PID_AUTH_CALLBACK = Callable[
+    [str, str, str],
+    Awaitable[
+        Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
+    ],
+]
+ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable]
+CHECK_AUTH_CALLBACK = Callable[
+    [str, str, JsonDict],
+    Awaitable[
+        Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
+    ],
+]
+
+
+class PasswordAuthProvider:
+    """
+    A class that the AuthHandler calls when authenticating users
+    It allows modules to provide alternative methods for authentication
+    """
+
+    def __init__(self) -> None:
+        # lists of callbacks
+        self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
+        self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
+
+        # Mapping from login type to login parameters
+        self._supported_login_types: Dict[str, Iterable[str]] = {}
+
+        # Mapping from login type to auth checker callbacks
+        self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
+
+    def register_password_auth_provider_callbacks(
+        self,
+        check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
+        on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
+        auth_checkers: Optional[Dict[Tuple[str, Tuple], CHECK_AUTH_CALLBACK]] = None,
+    ) -> None:
+        # Register check_3pid_auth callback
+        if check_3pid_auth is not None:
+            self.check_3pid_auth_callbacks.append(check_3pid_auth)
+
+        # register on_logged_out callback
+        if on_logged_out is not None:
+            self.on_logged_out_callbacks.append(on_logged_out)
+
+        if auth_checkers is not None:
+            # register a new supported login_type
+            # Iterate through all of the types being registered
+            for (login_type, fields), callback in auth_checkers.items():
+                # Note: fields may be empty here. This would allow a modules auth checker to
+                # be called with just 'login_type' and no password or other secrets
+
+                # Need to check that all the field names are strings or may get nasty errors later
+                for f in fields:
+                    if not isinstance(f, str):
+                        raise RuntimeError(
+                            "A module tried to register support for login type: %s with parameters %s"
+                            " but all parameter names must be strings"
+                            % (login_type, fields)
+                        )
+
+                # 2 modules supporting the same login type must expect the same fields
+                # e.g. 1 can't expect "pass" if the other expects "password"
+                # so throw an exception if that happens
+                if login_type not in self._supported_login_types.get(login_type, []):
+                    self._supported_login_types[login_type] = fields
+                else:
+                    fields_currently_supported = self._supported_login_types.get(
+                        login_type
+                    )
+                    if fields_currently_supported != fields:
+                        raise RuntimeError(
+                            "A module tried to register support for login type: %s with parameters %s"
+                            " but another module had already registered support for that type with parameters %s"
+                            % (login_type, fields, fields_currently_supported)
+                        )
+
+                # Add the new method to the list of auth_checker_callbacks for this login type
+                self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
 
     def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
         """Get the login types supported by this password provider
@@ -1852,20 +2038,15 @@ class PasswordProvider:
         Returns a map from a login type identifier (such as m.login.password) to an
         iterable giving the fields which must be provided by the user in the submission
         to the /login API.
-
-        This wrapper adds m.login.password to the list if the underlying password
-        provider supports the check_password() api.
         """
+
         return self._supported_login_types
 
     async def check_auth(
         self, username: str, login_type: str, login_dict: JsonDict
-    ) -> Optional[Tuple[str, Optional[Callable]]]:
+    ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
         """Check if the user has presented valid login credentials
 
-        This wrapper also calls check_password() if the underlying password provider
-        supports the check_password() api and the login type is m.login.password.
-
         Args:
             username: user id presented by the client. Either an MXID or an unqualified
                 username.
@@ -1879,63 +2060,130 @@ class PasswordProvider:
             user, and `callback` is an optional callback which will be called with the
             result from the /login call (including access_token, device_id, etc.)
         """
-        # first grandfather in a call to check_password
-        if login_type == LoginType.PASSWORD:
-            check_password = getattr(self._pp, "check_password", None)
-            if check_password:
-                qualified_user_id = self._module_api.get_qualified_user_id(username)
-                is_valid = await check_password(
-                    qualified_user_id, login_dict["password"]
-                )
-                if is_valid:
-                    return qualified_user_id, None
 
-        check_auth = getattr(self._pp, "check_auth", None)
-        if not check_auth:
-            return None
-        result = await check_auth(username, login_type, login_dict)
+        # Go through all callbacks for the login type until one returns with a value
+        # other than None (i.e. until a callback returns a success)
+        for callback in self.auth_checker_callbacks[login_type]:
+            try:
+                result = await callback(username, login_type, login_dict)
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+                continue
 
-        # Check if the return value is a str or a tuple
-        if isinstance(result, str):
-            # If it's a str, set callback function to None
-            return result, None
+            if result is not None:
+                # Check that the callback returned a Tuple[str, Optional[Callable]]
+                # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
+                # result is always the right type, but as it is 3rd party code it might not be
+
+                if not isinstance(result, tuple) or len(result) != 2:
+                    logger.warning(
+                        "Wrong type returned by module API callback %s: %s, expected"
+                        " Optional[Tuple[str, Optional[Callable]]]",
+                        callback,
+                        result,
+                    )
+                    continue
 
-        return result
+                # pull out the two parts of the tuple so we can do type checking
+                str_result, callback_result = result
+
+                # the 1st item in the tuple should be a str
+                if not isinstance(str_result, str):
+                    logger.warning(  # type: ignore[unreachable]
+                        "Wrong type returned by module API callback %s: %s, expected"
+                        " Optional[Tuple[str, Optional[Callable]]]",
+                        callback,
+                        result,
+                    )
+                    continue
+
+                # the second should be Optional[Callable]
+                if callback_result is not None:
+                    if not callable(callback_result):
+                        logger.warning(  # type: ignore[unreachable]
+                            "Wrong type returned by module API callback %s: %s, expected"
+                            " Optional[Tuple[str, Optional[Callable]]]",
+                            callback,
+                            result,
+                        )
+                        continue
+
+                # The result is a (str, Optional[callback]) tuple so return the successful result
+                return result
+
+        # If this point has been reached then none of the callbacks successfully authenticated
+        # the user so return None
+        return None
 
     async def check_3pid_auth(
         self, medium: str, address: str, password: str
-    ) -> Optional[Tuple[str, Optional[Callable]]]:
-        g = getattr(self._pp, "check_3pid_auth", None)
-        if not g:
-            return None
-
+    ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
         # This function is able to return a deferred that either
         # resolves None, meaning authentication failure, or upon
         # success, to a str (which is the user_id) or a tuple of
         # (user_id, callback_func), where callback_func should be run
         # after we've finished everything else
-        result = await g(medium, address, password)
 
-        # Check if the return value is a str or a tuple
-        if isinstance(result, str):
-            # If it's a str, set callback function to None
-            return result, None
+        for callback in self.check_3pid_auth_callbacks:
+            try:
+                result = await callback(medium, address, password)
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+                continue
 
-        return result
+            if result is not None:
+                # Check that the callback returned a Tuple[str, Optional[Callable]]
+                # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
+                # result is always the right type, but as it is 3rd party code it might not be
+
+                if not isinstance(result, tuple) or len(result) != 2:
+                    logger.warning(
+                        "Wrong type returned by module API callback %s: %s, expected"
+                        " Optional[Tuple[str, Optional[Callable]]]",
+                        callback,
+                        result,
+                    )
+                    continue
+
+                # pull out the two parts of the tuple so we can do type checking
+                str_result, callback_result = result
+
+                # the 1st item in the tuple should be a str
+                if not isinstance(str_result, str):
+                    logger.warning(  # type: ignore[unreachable]
+                        "Wrong type returned by module API callback %s: %s, expected"
+                        " Optional[Tuple[str, Optional[Callable]]]",
+                        callback,
+                        result,
+                    )
+                    continue
+
+                # the second should be Optional[Callable]
+                if callback_result is not None:
+                    if not callable(callback_result):
+                        logger.warning(  # type: ignore[unreachable]
+                            "Wrong type returned by module API callback %s: %s, expected"
+                            " Optional[Tuple[str, Optional[Callable]]]",
+                            callback,
+                            result,
+                        )
+                        continue
+
+                # The result is a (str, Optional[callback]) tuple so return the successful result
+                return result
+
+        # If this point has been reached then none of the callbacks successfully authenticated
+        # the user so return None
+        return None
 
     async def on_logged_out(
         self, user_id: str, device_id: Optional[str], access_token: str
     ) -> None:
-        g = getattr(self._pp, "on_logged_out", None)
-        if not g:
-            return
 
-        # This might return an awaitable, if it does block the log out
-        # until it completes.
-        await maybe_awaitable(
-            g(
-                user_id=user_id,
-                device_id=device_id,
-                access_token=access_token,
-            )
-        )
+        # call all of the on_logged_out callbacks
+        for callback in self.on_logged_out_callbacks:
+            try:
+                callback(user_id, device_id, access_token)
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+                continue