summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py288
1 files changed, 184 insertions, 104 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c7dc07008a..a19c556437 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -14,7 +14,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import inspect
 import logging
 import time
 import unicodedata
@@ -22,6 +21,7 @@ import urllib.parse
 from typing import (
     TYPE_CHECKING,
     Any,
+    Awaitable,
     Callable,
     Dict,
     Iterable,
@@ -36,6 +36,8 @@ import attr
 import bcrypt
 import pymacaroons
 
+from twisted.web.http import Request
+
 from synapse.api.constants import LoginType
 from synapse.api.errors import (
     AuthError,
@@ -47,20 +49,25 @@ from synapse.api.errors import (
     UserDeactivatedError,
 )
 from synapse.api.ratelimiting import Ratelimiter
-from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.ui_auth import (
+    INTERACTIVE_AUTH_CHECKERS,
+    UIAuthSessionDataConstants,
+)
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
+from synapse.http import get_request_user_agent
 from synapse.http.server import finish_request, respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.module_api import ModuleApi
+from synapse.storage.roommember import ProfileInfo
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util import stringutils as stringutils
+from synapse.util.async_helpers import maybe_awaitable
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.threepids import canonicalise_email
 
-from ._base import BaseHandler
-
 if TYPE_CHECKING:
     from synapse.app.homeserver import HomeServer
 
@@ -193,39 +200,27 @@ class AuthHandler(BaseHandler):
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.macaroon_gen = hs.get_macaroon_generator()
         self._password_enabled = hs.config.password_enabled
-        self._sso_enabled = (
-            hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
-        )
-
-        # we keep this as a list despite the O(N^2) implication so that we can
-        # keep PASSWORD first and avoid confusing clients which pick the first
-        # type in the list. (NB that the spec doesn't require us to do so and
-        # clients which favour types that they don't understand over those that
-        # they do are technically broken)
+        self._password_localdb_enabled = hs.config.password_localdb_enabled
 
         # start out by assuming PASSWORD is enabled; we will remove it later if not.
-        login_types = []
-        if hs.config.password_localdb_enabled:
-            login_types.append(LoginType.PASSWORD)
+        login_types = set()
+        if self._password_localdb_enabled:
+            login_types.add(LoginType.PASSWORD)
 
         for provider in self.password_providers:
-            if hasattr(provider, "get_supported_login_types"):
-                for t in provider.get_supported_login_types().keys():
-                    if t not in login_types:
-                        login_types.append(t)
+            login_types.update(provider.get_supported_login_types().keys())
 
         if not self._password_enabled:
+            login_types.discard(LoginType.PASSWORD)
+
+        # Some clients just pick the first type in the list. In this case, we want
+        # them to use PASSWORD (rather than token or whatever), so we want to make sure
+        # that comes first, where it's present.
+        self._supported_login_types = []
+        if LoginType.PASSWORD in login_types:
+            self._supported_login_types.append(LoginType.PASSWORD)
             login_types.remove(LoginType.PASSWORD)
-
-        self._supported_login_types = login_types
-
-        # Login types and UI Auth types have a heavy overlap, but are not
-        # necessarily identical. Login types have SSO (and other login types)
-        # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
-        ui_auth_types = login_types.copy()
-        if self._sso_enabled:
-            ui_auth_types.append(LoginType.SSO)
-        self._supported_ui_auth_types = ui_auth_types
+        self._supported_login_types.extend(login_types)
 
         # Ratelimiter for failed auth during UIA. Uses same ratelimit config
         # as per `rc_login.failed_attempts`.
@@ -235,6 +230,9 @@ class AuthHandler(BaseHandler):
             burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
         )
 
+        # The number of seconds to keep a UI auth session active.
+        self._ui_auth_session_timeout = hs.config.ui_auth_session_timeout
+
         # Ratelimitier for failed /login attempts
         self._failed_login_attempts_ratelimiter = Ratelimiter(
             clock=hs.get_clock(),
@@ -266,10 +264,6 @@ class AuthHandler(BaseHandler):
         # authenticating for an operation to occur on their account.
         self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
 
-        # The following template is shown after a successful user interactive
-        # authentication session. It tells the user they can close the window.
-        self._sso_auth_success_template = hs.config.sso_auth_success_template
-
         # The following template is shown during the SSO authentication process if
         # the account is deactivated.
         self._sso_account_deactivated_template = (
@@ -290,9 +284,8 @@ class AuthHandler(BaseHandler):
         requester: Requester,
         request: SynapseRequest,
         request_body: Dict[str, Any],
-        clientip: str,
         description: str,
-    ) -> Tuple[dict, str]:
+    ) -> Tuple[dict, Optional[str]]:
         """
         Checks that the user is who they claim to be, via a UI auth.
 
@@ -307,8 +300,6 @@ class AuthHandler(BaseHandler):
 
             request_body: The body of the request sent by the client
 
-            clientip: The IP address of the client.
-
             description: A human readable string to be displayed to the user that
                          describes the operation happening on their account.
 
@@ -319,7 +310,8 @@ class AuthHandler(BaseHandler):
                 have been given only in a previous call).
 
                 'session_id' is the ID of this session, either passed in by the
-                client or assigned by this call
+                client or assigned by this call. This is None if UI auth was
+                skipped (by re-using a previous validation).
 
         Raises:
             InteractiveAuthIncompleteError if the client has not yet completed
@@ -333,40 +325,90 @@ class AuthHandler(BaseHandler):
 
         """
 
-        user_id = requester.user.to_string()
+        if self._ui_auth_session_timeout:
+            last_validated = await self.store.get_access_token_last_validated(
+                requester.access_token_id
+            )
+            if self.clock.time_msec() - last_validated < self._ui_auth_session_timeout:
+                # Return the input parameters, minus the auth key, which matches
+                # the logic in check_ui_auth.
+                request_body.pop("auth", None)
+                return request_body, None
+
+        requester_user_id = requester.user.to_string()
 
         # Check if we should be ratelimited due to too many previous failed attempts
-        self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
+        self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
 
         # build a list of supported flows
-        flows = [[login_type] for login_type in self._supported_ui_auth_types]
+        supported_ui_auth_types = await self._get_available_ui_auth_types(
+            requester.user
+        )
+        flows = [[login_type] for login_type in supported_ui_auth_types]
+
+        def get_new_session_data() -> JsonDict:
+            return {UIAuthSessionDataConstants.REQUEST_USER_ID: requester_user_id}
 
         try:
             result, params, session_id = await self.check_ui_auth(
-                flows, request, request_body, clientip, description
+                flows, request, request_body, description, get_new_session_data,
             )
         except LoginError:
             # Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
-            self._failed_uia_attempts_ratelimiter.can_do_action(user_id)
+            self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
             raise
 
         # find the completed login type
-        for login_type in self._supported_ui_auth_types:
+        for login_type in supported_ui_auth_types:
             if login_type not in result:
                 continue
 
-            user_id = result[login_type]
+            validated_user_id = result[login_type]
             break
         else:
             # this can't happen
             raise Exception("check_auth returned True but no successful login type")
 
         # check that the UI auth matched the access token
-        if user_id != requester.user.to_string():
+        if validated_user_id != requester_user_id:
             raise AuthError(403, "Invalid auth")
 
+        # Note that the access token has been validated.
+        await self.store.update_access_token_last_validated(requester.access_token_id)
+
         return params, session_id
 
+    async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
+        """Get a list of the authentication types this user can use
+        """
+
+        ui_auth_types = set()
+
+        # if the HS supports password auth, and the user has a non-null password, we
+        # support password auth
+        if self._password_localdb_enabled and self._password_enabled:
+            lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
+            if lookupres:
+                _, password_hash = lookupres
+                if password_hash:
+                    ui_auth_types.add(LoginType.PASSWORD)
+
+        # also allow auth from password providers
+        for provider in self.password_providers:
+            for t in provider.get_supported_login_types().keys():
+                if t == LoginType.PASSWORD and not self._password_enabled:
+                    continue
+                ui_auth_types.add(t)
+
+        # if sso is enabled, allow the user to log in via SSO iff they have a mapping
+        # from sso to mxid.
+        if await self.hs.get_sso_handler().get_identity_providers_for_user(
+            user.to_string()
+        ):
+            ui_auth_types.add(LoginType.SSO)
+
+        return ui_auth_types
+
     def get_enabled_auth_types(self):
         """Return the enabled user-interactive authentication types
 
@@ -380,8 +422,8 @@ class AuthHandler(BaseHandler):
         flows: List[List[str]],
         request: SynapseRequest,
         clientdict: Dict[str, Any],
-        clientip: str,
         description: str,
+        get_new_session_data: Optional[Callable[[], JsonDict]] = None,
     ) -> Tuple[dict, dict, str]:
         """
         Takes a dictionary sent by the client in the login / registration
@@ -402,11 +444,16 @@ class AuthHandler(BaseHandler):
             clientdict: The dictionary from the client root level, not the
                         'auth' key: this method prompts for auth if none is sent.
 
-            clientip: The IP address of the client.
-
             description: A human readable string to be displayed to the user that
                          describes the operation happening on their account.
 
+            get_new_session_data:
+                an optional callback which will be called when starting a new session.
+                it should return data to be stored as part of the session.
+
+                The keys of the returned data should be entries in
+                UIAuthSessionDataConstants.
+
         Returns:
             A tuple of (creds, params, session_id).
 
@@ -423,13 +470,10 @@ class AuthHandler(BaseHandler):
                 all the stages in any of the permitted flows.
         """
 
-        authdict = None
         sid = None  # type: Optional[str]
-        if clientdict and "auth" in clientdict:
-            authdict = clientdict["auth"]
-            del clientdict["auth"]
-            if "session" in authdict:
-                sid = authdict["session"]
+        authdict = clientdict.pop("auth", {})
+        if "session" in authdict:
+            sid = authdict["session"]
 
         # Convert the URI and method to strings.
         uri = request.uri.decode("utf-8")
@@ -437,10 +481,15 @@ class AuthHandler(BaseHandler):
 
         # If there's no session ID, create a new session.
         if not sid:
+            new_session_data = get_new_session_data() if get_new_session_data else {}
+
             session = await self.store.create_ui_auth_session(
                 clientdict, uri, method, description
             )
 
+            for k, v in new_session_data.items():
+                await self.set_session_data(session.session_id, k, v)
+
         else:
             try:
                 session = await self.store.get_ui_auth_session(sid)
@@ -496,7 +545,8 @@ class AuthHandler(BaseHandler):
             # authentication flow.
             await self.store.set_ui_auth_clientdict(sid, clientdict)
 
-        user_agent = request.get_user_agent("")
+        user_agent = get_request_user_agent(request)
+        clientip = request.getClientIP()
 
         await self.store.add_user_agent_ip_to_ui_auth_session(
             session.session_id, user_agent, clientip
@@ -518,22 +568,14 @@ class AuthHandler(BaseHandler):
                         session.session_id, login_type, result
                     )
             except LoginError as e:
-                if login_type == LoginType.EMAIL_IDENTITY:
-                    # riot used to have a bug where it would request a new
-                    # validation token (thus sending a new email) each time it
-                    # got a 401 with a 'flows' field.
-                    # (https://github.com/vector-im/vector-web/issues/2447).
-                    #
-                    # Grandfather in the old behaviour for now to avoid
-                    # breaking old riot deployments.
-                    raise
-
                 # this step failed. Merge the error dict into the response
                 # so that the client can have another go.
                 errordict = e.error_dict()
 
         creds = await self.store.get_completed_ui_auth_stages(session.session_id)
         for f in flows:
+            # If all the required credentials have been supplied, the user has
+            # successfully completed the UI auth process!
             if len(set(f) - set(creds)) == 0:
                 # it's very useful to know what args are stored, but this can
                 # include the password in the case of registering, so only log
@@ -599,7 +641,8 @@ class AuthHandler(BaseHandler):
 
         Args:
             session_id: The ID of this session as returned from check_auth
-            key: The key to store the data under
+            key: The key to store the data under. An entry from
+                UIAuthSessionDataConstants.
             value: The data to store
         """
         try:
@@ -615,7 +658,8 @@ class AuthHandler(BaseHandler):
 
         Args:
             session_id: The ID of this session as returned from check_auth
-            key: The key to store the data under
+            key: The key the data was stored under. An entry from
+                UIAuthSessionDataConstants.
             default: Value to return if the key has not been set
         """
         try:
@@ -709,6 +753,7 @@ class AuthHandler(BaseHandler):
         device_id: Optional[str],
         valid_until_ms: Optional[int],
         puppets_user_id: Optional[str] = None,
+        is_appservice_ghost: bool = False,
     ) -> str:
         """
         Creates a new access token for the user with the given user ID.
@@ -725,6 +770,7 @@ class AuthHandler(BaseHandler):
                we should always have a device ID)
             valid_until_ms: when the token is valid until. None for
                 no expiry.
+            is_appservice_ghost: Whether the user is an application ghost user
         Returns:
               The access token for the user's session.
         Raises:
@@ -745,7 +791,11 @@ class AuthHandler(BaseHandler):
                 "Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
             )
 
-        await self.auth.check_auth_blocking(user_id)
+        if (
+            not is_appservice_ghost
+            or self.hs.config.appservice.track_appservice_user_ips
+        ):
+            await self.auth.check_auth_blocking(user_id)
 
         access_token = self.macaroon_gen.generate_access_token(user_id)
         await self.store.add_access_token_to_user(
@@ -831,7 +881,7 @@ class AuthHandler(BaseHandler):
 
     async def validate_login(
         self, login_submission: Dict[str, Any], ratelimit: bool = False,
-    ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+    ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
         """Authenticates the user for the /login API
 
         Also used by the user-interactive auth flow to validate auth types which don't
@@ -974,7 +1024,7 @@ class AuthHandler(BaseHandler):
 
     async def _validate_userid_login(
         self, username: str, login_submission: Dict[str, Any],
-    ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+    ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
         """Helper for validate_login
 
         Handles login, once we've mapped 3pids onto userids
@@ -1029,7 +1079,7 @@ class AuthHandler(BaseHandler):
             if result:
                 return result
 
-        if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
+        if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
             known_login_type = True
 
             # we've already checked that there is a (valid) password field
@@ -1052,7 +1102,7 @@ class AuthHandler(BaseHandler):
 
     async def check_password_provider_3pid(
         self, medium: str, address: str, password: str
-    ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]:
+    ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
         """Check if a password provider is able to validate a thirdparty login
 
         Args:
@@ -1283,12 +1333,12 @@ class AuthHandler(BaseHandler):
         else:
             return False
 
-    async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
+    async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> str:
         """
         Get the HTML for the SSO redirect confirmation page.
 
         Args:
-            redirect_url: The URL to redirect to the SSO provider.
+            request: The incoming HTTP request
             session_id: The user interactive authentication session ID.
 
         Returns:
@@ -1298,38 +1348,48 @@ class AuthHandler(BaseHandler):
             session = await self.store.get_ui_auth_session(session_id)
         except StoreError:
             raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
-        return self._sso_auth_confirm_template.render(
-            description=session.description, redirect_url=redirect_url,
+
+        user_id_to_verify = await self.get_session_data(
+            session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
+        )  # type: str
+
+        idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
+            user_id_to_verify
         )
 
-    async def complete_sso_ui_auth(
-        self, registered_user_id: str, session_id: str, request: SynapseRequest,
-    ):
-        """Having figured out a mxid for this user, complete the HTTP request
+        if not idps:
+            # we checked that the user had some remote identities before offering an SSO
+            # flow, so either it's been deleted or the client has requested SSO despite
+            # it not being offered.
+            raise SynapseError(400, "User has no SSO identities")
 
-        Args:
-            registered_user_id: The registered user ID to complete SSO login for.
-            request: The request to complete.
-            client_redirect_url: The URL to which to redirect the user at the end of the
-                process.
-        """
-        # Mark the stage of the authentication as successful.
-        # Save the user who authenticated with SSO, this will be used to ensure
-        # that the account be modified is also the person who logged in.
-        await self.store.mark_ui_auth_stage_complete(
-            session_id, LoginType.SSO, registered_user_id
+        # for now, just pick one
+        idp_id, sso_auth_provider = next(iter(idps.items()))
+        if len(idps) > 0:
+            logger.warning(
+                "User %r has previously logged in with multiple SSO IdPs; arbitrarily "
+                "picking %r",
+                user_id_to_verify,
+                idp_id,
+            )
+
+        redirect_url = await sso_auth_provider.handle_redirect_request(
+            request, None, session_id
         )
 
-        # Render the HTML and return.
-        html = self._sso_auth_success_template
-        respond_with_html(request, 200, html)
+        return self._sso_auth_confirm_template.render(
+            description=session.description,
+            redirect_url=redirect_url,
+            idp=sso_auth_provider,
+        )
 
     async def complete_sso_login(
         self,
         registered_user_id: str,
-        request: SynapseRequest,
+        request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
+        new_user: bool = False,
     ):
         """Having figured out a mxid for this user, complete the HTTP request
 
@@ -1340,6 +1400,8 @@ class AuthHandler(BaseHandler):
                 process.
             extra_attributes: Extra attributes which will be passed to the client
                 during successful login. Must be JSON serializable.
+            new_user: True if we should use wording appropriate to a user who has just
+                registered.
         """
         # If the account has been deactivated, do not proceed with the login
         # flow.
@@ -1348,22 +1410,37 @@ class AuthHandler(BaseHandler):
             respond_with_html(request, 403, self._sso_account_deactivated_template)
             return
 
+        profile = await self.store.get_profileinfo(
+            UserID.from_string(registered_user_id).localpart
+        )
+
         self._complete_sso_login(
-            registered_user_id, request, client_redirect_url, extra_attributes
+            registered_user_id,
+            request,
+            client_redirect_url,
+            extra_attributes,
+            new_user=new_user,
+            user_profile_data=profile,
         )
 
     def _complete_sso_login(
         self,
         registered_user_id: str,
-        request: SynapseRequest,
+        request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
+        new_user: bool = False,
+        user_profile_data: Optional[ProfileInfo] = None,
     ):
         """
         The synchronous portion of complete_sso_login.
 
         This exists purely for backwards compatibility of synapse.module_api.ModuleApi.
         """
+
+        if user_profile_data is None:
+            user_profile_data = ProfileInfo(None, None)
+
         # Store any extra attributes which will be passed in the login response.
         # Note that this is per-user so it may overwrite a previous value, this
         # is considered OK since the newest SSO attributes should be most valid.
@@ -1401,6 +1478,9 @@ class AuthHandler(BaseHandler):
             display_url=redirect_url_no_params,
             redirect_url=redirect_url,
             server_name=self._server_name,
+            new_user=new_user,
+            user_id=registered_user_id,
+            user_profile=user_profile_data,
         )
         respond_with_html(request, 200, html)
 
@@ -1438,8 +1518,8 @@ class AuthHandler(BaseHandler):
     @staticmethod
     def add_query_param_to_url(url: str, param_name: str, param: Any):
         url_parts = list(urllib.parse.urlparse(url))
-        query = dict(urllib.parse.parse_qsl(url_parts[4]))
-        query.update({param_name: param})
+        query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
+        query.append((param_name, param))
         url_parts[4] = urllib.parse.urlencode(query)
         return urllib.parse.urlunparse(url_parts)
 
@@ -1609,6 +1689,6 @@ class PasswordProvider:
 
         # This might return an awaitable, if it does block the log out
         # until it completes.
-        result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
-        if inspect.isawaitable(result):
-            await result
+        await maybe_awaitable(
+            g(user_id=user_id, device_id=device_id, access_token=access_token,)
+        )