summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2020-12-02 10:38:50 +0000
committerGitHub <noreply@github.com>2020-12-02 10:38:50 +0000
commitd3ed93504bb6bb8ad138e356e3c74b6a7286299b (patch)
treed8742391d03f5adfed56526e5170befc60078d33
parentAllow specifying room version in 'RestHelper.create_room_as' and add typing (... (diff)
downloadsynapse-d3ed93504bb6bb8ad138e356e3c74b6a7286299b.tar.xz
Create a `PasswordProvider` wrapper object (#8849)
The idea here is to abstract out all the conditional code which tests which
methods a given password provider has, to provide a consistent interface.
-rw-r--r--changelog.d/8849.misc1
-rw-r--r--synapse/handlers/auth.py203
-rw-r--r--tests/handlers/test_password_providers.py5
3 files changed, 152 insertions, 57 deletions
diff --git a/changelog.d/8849.misc b/changelog.d/8849.misc
new file mode 100644
index 0000000000..3dd496ce61
--- /dev/null
+++ b/changelog.d/8849.misc
@@ -0,0 +1 @@
+Refactor `password_auth_provider` support code.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 8815f685b9..c7dc07008a 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014 - 2016 OpenMarket Ltd
 # Copyright 2017 Vector Creations Ltd
+# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -25,6 +26,7 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Tuple,
     Union,
@@ -181,17 +183,12 @@ class AuthHandler(BaseHandler):
         #   better way to break the loop
         account_handler = ModuleApi(hs, self)
 
-        self.password_providers = []
-        for module, config in hs.config.password_providers:
-            try:
-                self.password_providers.append(
-                    module(config=config, account_handler=account_handler)
-                )
-            except Exception as e:
-                logger.error("Error while initializing %r: %s", module, e)
-                raise
+        self.password_providers = [
+            PasswordProvider.load(module, config, account_handler)
+            for module, config in hs.config.password_providers
+        ]
 
-        logger.info("Extra password_providers: %r", self.password_providers)
+        logger.info("Extra password_providers: %s", self.password_providers)
 
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.macaroon_gen = hs.get_macaroon_generator()
@@ -853,6 +850,8 @@ class AuthHandler(BaseHandler):
             LoginError if there was an authentication problem.
         """
         login_type = login_submission.get("type")
+        if not isinstance(login_type, str):
+            raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM)
 
         # ideally, we wouldn't be checking the identifier unless we know we have a login
         # method which uses it (https://github.com/matrix-org/synapse/issues/8836)
@@ -998,24 +997,12 @@ class AuthHandler(BaseHandler):
             qualified_user_id = UserID(username, self.hs.hostname).to_string()
 
         login_type = login_submission.get("type")
+        # we already checked that we have a valid login type
+        assert isinstance(login_type, str)
+
         known_login_type = False
 
         for provider in self.password_providers:
-            if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
-                known_login_type = True
-                # we've already checked that there is a (valid) password field
-                is_valid = await provider.check_password(
-                    qualified_user_id, login_submission["password"]
-                )
-                if is_valid:
-                    return qualified_user_id, None
-
-            if not hasattr(provider, "get_supported_login_types") or not hasattr(
-                provider, "check_auth"
-            ):
-                # this password provider doesn't understand custom login types
-                continue
-
             supported_login_types = provider.get_supported_login_types()
             if login_type not in supported_login_types:
                 # this password provider doesn't understand this login type
@@ -1040,8 +1027,6 @@ class AuthHandler(BaseHandler):
 
             result = await provider.check_auth(username, login_type, login_dict)
             if result:
-                if isinstance(result, str):
-                    result = (result, None)
                 return result
 
         if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
@@ -1083,19 +1068,9 @@ class AuthHandler(BaseHandler):
             unsuccessful, `user_id` and `callback` are both `None`.
         """
         for provider in self.password_providers:
-            if hasattr(provider, "check_3pid_auth"):
-                # This function is able to return a deferred that either
-                # resolves None, meaning authentication failure, or upon
-                # success, to a str (which is the user_id) or a tuple of
-                # (user_id, callback_func), where callback_func should be run
-                # after we've finished everything else
-                result = await provider.check_3pid_auth(medium, address, password)
-                if result:
-                    # Check if the return value is a str or a tuple
-                    if isinstance(result, str):
-                        # If it's a str, set callback function to None
-                        result = (result, None)
-                    return result
+            result = await provider.check_3pid_auth(medium, address, password)
+            if result:
+                return result
 
         return None, None
 
@@ -1153,16 +1128,11 @@ class AuthHandler(BaseHandler):
 
         # see if any of our auth providers want to know about this
         for provider in self.password_providers:
-            if hasattr(provider, "on_logged_out"):
-                # This might return an awaitable, if it does block the log out
-                # until it completes.
-                result = provider.on_logged_out(
-                    user_id=user_info.user_id,
-                    device_id=user_info.device_id,
-                    access_token=access_token,
-                )
-                if inspect.isawaitable(result):
-                    await result
+            await provider.on_logged_out(
+                user_id=user_info.user_id,
+                device_id=user_info.device_id,
+                access_token=access_token,
+            )
 
         # delete pushers associated with this access token
         if user_info.token_id is not None:
@@ -1191,11 +1161,10 @@ class AuthHandler(BaseHandler):
 
         # see if any of our auth providers want to know about this
         for provider in self.password_providers:
-            if hasattr(provider, "on_logged_out"):
-                for token, token_id, device_id in tokens_and_devices:
-                    await provider.on_logged_out(
-                        user_id=user_id, device_id=device_id, access_token=token
-                    )
+            for token, token_id, device_id in tokens_and_devices:
+                await provider.on_logged_out(
+                    user_id=user_id, device_id=device_id, access_token=token
+                )
 
         # delete pushers associated with the access tokens
         await self.hs.get_pusherpool().remove_pushers_by_access_token(
@@ -1519,3 +1488,127 @@ class MacaroonGenerator:
         macaroon.add_first_party_caveat("gen = 1")
         macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
         return macaroon
+
+
+class PasswordProvider:
+    """Wrapper for a password auth provider module
+
+    This class abstracts out all of the backwards-compatibility hacks for
+    password providers, to provide a consistent interface.
+    """
+
+    @classmethod
+    def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
+        try:
+            pp = module(config=config, account_handler=module_api)
+        except Exception as e:
+            logger.error("Error while initializing %r: %s", module, e)
+            raise
+        return cls(pp, module_api)
+
+    def __init__(self, pp, module_api: ModuleApi):
+        self._pp = pp
+        self._module_api = module_api
+
+        self._supported_login_types = {}
+
+        # grandfather in check_password support
+        if hasattr(self._pp, "check_password"):
+            self._supported_login_types[LoginType.PASSWORD] = ("password",)
+
+        g = getattr(self._pp, "get_supported_login_types", None)
+        if g:
+            self._supported_login_types.update(g())
+
+    def __str__(self):
+        return str(self._pp)
+
+    def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
+        """Get the login types supported by this password provider
+
+        Returns a map from a login type identifier (such as m.login.password) to an
+        iterable giving the fields which must be provided by the user in the submission
+        to the /login API.
+
+        This wrapper adds m.login.password to the list if the underlying password
+        provider supports the check_password() api.
+        """
+        return self._supported_login_types
+
+    async def check_auth(
+        self, username: str, login_type: str, login_dict: JsonDict
+    ) -> Optional[Tuple[str, Optional[Callable]]]:
+        """Check if the user has presented valid login credentials
+
+        This wrapper also calls check_password() if the underlying password provider
+        supports the check_password() api and the login type is m.login.password.
+
+        Args:
+            username: user id presented by the client. Either an MXID or an unqualified
+                username.
+
+            login_type: the login type being attempted - one of the types returned by
+                get_supported_login_types()
+
+            login_dict: the dictionary of login secrets passed by the client.
+
+        Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
+            user, and `callback` is an optional callback which will be called with the
+            result from the /login call (including access_token, device_id, etc.)
+        """
+        # first grandfather in a call to check_password
+        if login_type == LoginType.PASSWORD:
+            g = getattr(self._pp, "check_password", None)
+            if g:
+                qualified_user_id = self._module_api.get_qualified_user_id(username)
+                is_valid = await self._pp.check_password(
+                    qualified_user_id, login_dict["password"]
+                )
+                if is_valid:
+                    return qualified_user_id, None
+
+        g = getattr(self._pp, "check_auth", None)
+        if not g:
+            return None
+        result = await g(username, login_type, login_dict)
+
+        # Check if the return value is a str or a tuple
+        if isinstance(result, str):
+            # If it's a str, set callback function to None
+            return result, None
+
+        return result
+
+    async def check_3pid_auth(
+        self, medium: str, address: str, password: str
+    ) -> Optional[Tuple[str, Optional[Callable]]]:
+        g = getattr(self._pp, "check_3pid_auth", None)
+        if not g:
+            return None
+
+        # This function is able to return a deferred that either
+        # resolves None, meaning authentication failure, or upon
+        # success, to a str (which is the user_id) or a tuple of
+        # (user_id, callback_func), where callback_func should be run
+        # after we've finished everything else
+        result = await g(medium, address, password)
+
+        # Check if the return value is a str or a tuple
+        if isinstance(result, str):
+            # If it's a str, set callback function to None
+            return result, None
+
+        return result
+
+    async def on_logged_out(
+        self, user_id: str, device_id: Optional[str], access_token: str
+    ) -> None:
+        g = getattr(self._pp, "on_logged_out", None)
+        if not g:
+            return
+
+        # This might return an awaitable, if it does block the log out
+        # until it completes.
+        result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
+        if inspect.isawaitable(result):
+            await result
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 22b9a11dc0..ceaf0902d2 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -266,8 +266,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         # first delete should give a 401
         channel = self._delete_device(tok1, "dev2")
         self.assertEqual(channel.code, 401)
-        # there are no valid flows here!
-        self.assertEqual(channel.json_body["flows"], [])
+        # m.login.password UIA is permitted because the auth provider allows it,
+        # even though the localdb does not.
+        self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}])
         session = channel.json_body["session"]
         mock_password_provider.check_password.assert_not_called()