summary refs log tree commit diff
path: root/synapse/handlers/sso.py
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2021-01-05 11:32:29 +0000
committerRichard van der Hoff <richard@matrix.org>2021-01-05 11:32:29 +0000
commit97d12dcf56f6a50d53c80b5d9a69fa66d175c05e (patch)
tree5d0d91e8b737f2b0ba8d691e9efd83ace71d7cb7 /synapse/handlers/sso.py
parentAllow redacting events on workers (#8994) (diff)
parentAdd type hints to the receipts and user directory handlers. (#8976) (diff)
downloadsynapse-97d12dcf56f6a50d53c80b5d9a69fa66d175c05e.tar.xz
Merge remote-tracking branch 'origin/release-v1.25.0' into matrix-org-hotfixes
Diffstat (limited to 'synapse/handlers/sso.py')
-rw-r--r--synapse/handlers/sso.py323
1 files changed, 290 insertions, 33 deletions
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 112a7d5b2c..33cd6bc178 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -13,16 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
 
 import attr
+from typing_extensions import NoReturn
 
 from twisted.web.http import Request
 
-from synapse.api.errors import RedirectException
+from synapse.api.errors import RedirectException, SynapseError
 from synapse.http.server import respond_with_html
-from synapse.types import UserID, contains_invalid_mxid_characters
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
 from synapse.util.async_helpers import Linearizer
+from synapse.util.stringutils import random_string
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -39,16 +42,52 @@ class MappingException(Exception):
 
 @attr.s
 class UserAttributes:
-    localpart = attr.ib(type=str)
+    # the localpart of the mxid that the mapper has assigned to the user.
+    # if `None`, the mapper has not picked a userid, and the user should be prompted to
+    # enter one.
+    localpart = attr.ib(type=Optional[str])
     display_name = attr.ib(type=Optional[str], default=None)
     emails = attr.ib(type=List[str], default=attr.Factory(list))
 
 
+@attr.s(slots=True)
+class UsernameMappingSession:
+    """Data we track about SSO sessions"""
+
+    # A unique identifier for this SSO provider, e.g.  "oidc" or "saml".
+    auth_provider_id = attr.ib(type=str)
+
+    # user ID on the IdP server
+    remote_user_id = attr.ib(type=str)
+
+    # attributes returned by the ID mapper
+    display_name = attr.ib(type=Optional[str])
+    emails = attr.ib(type=List[str])
+
+    # An optional dictionary of extra attributes to be provided to the client in the
+    # login response.
+    extra_login_attributes = attr.ib(type=Optional[JsonDict])
+
+    # where to redirect the client back to
+    client_redirect_url = attr.ib(type=str)
+
+    # expiry time for the session, in milliseconds
+    expiry_time_ms = attr.ib(type=int)
+
+
+# the HTTP cookie used to track the mapping session id
+USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
+
+
 class SsoHandler:
     # The number of attempts to ask the mapping provider for when generating an MXID.
     _MAP_USERNAME_RETRIES = 1000
 
+    # the time a UsernameMappingSession remains valid for
+    _MAPPING_SESSION_VALIDITY_PERIOD_MS = 15 * 60 * 1000
+
     def __init__(self, hs: "HomeServer"):
+        self._clock = hs.get_clock()
         self._store = hs.get_datastore()
         self._server_name = hs.hostname
         self._registration_handler = hs.get_registration_handler()
@@ -58,8 +97,15 @@ class SsoHandler:
         # a lock on the mappings
         self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
 
+        # a map from session id to session data
+        self._username_mapping_sessions = {}  # type: Dict[str, UsernameMappingSession]
+
     def render_error(
-        self, request, error: str, error_description: Optional[str] = None
+        self,
+        request: Request,
+        error: str,
+        error_description: Optional[str] = None,
+        code: int = 400,
     ) -> None:
         """Renders the error template and responds with it.
 
@@ -71,11 +117,12 @@ class SsoHandler:
                 We'll respond with an HTML page describing the error.
             error: A technical identifier for this error.
             error_description: A human-readable description of the error.
+            code: The integer error code (an HTTP response code)
         """
         html = self._error_template.render(
             error=error, error_description=error_description
         )
-        respond_with_html(request, 400, html)
+        respond_with_html(request, code, html)
 
     async def get_sso_user_by_remote_user_id(
         self, auth_provider_id: str, remote_user_id: str
@@ -119,15 +166,16 @@ class SsoHandler:
         # No match.
         return None
 
-    async def get_mxid_from_sso(
+    async def complete_sso_login_request(
         self,
         auth_provider_id: str,
         remote_user_id: str,
-        user_agent: str,
-        ip_address: str,
+        request: SynapseRequest,
+        client_redirect_url: str,
         sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
-        grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
-    ) -> str:
+        grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
+        extra_login_attributes: Optional[JsonDict] = None,
+    ) -> None:
         """
         Given an SSO ID, retrieve the user ID for it and possibly register the user.
 
@@ -146,12 +194,18 @@ class SsoHandler:
         given user-agent and IP address and the SSO ID is linked to this matrix
         ID for subsequent calls.
 
+        Finally, we generate a redirect to the supplied redirect uri, with a login token
+
         Args:
             auth_provider_id: A unique identifier for this SSO provider, e.g.
                 "oidc" or "saml".
+
             remote_user_id: The unique identifier from the SSO provider.
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
+
+            request: The request to respond to
+
+            client_redirect_url: The redirect URL passed in by the client.
+
             sso_to_matrix_id_mapper: A callable to generate the user attributes.
                 The only parameter is an integer which represents the amount of
                 times the returned mxid localpart mapping has failed.
@@ -163,12 +217,13 @@ class SsoHandler:
                         to the user.
                     RedirectException to redirect to an additional page (e.g.
                         to prompt the user for more information).
+
             grandfather_existing_users: A callable which can return an previously
                 existing matrix ID. The SSO ID is then linked to the returned
                 matrix ID.
 
-        Returns:
-             The user ID associated with the SSO response.
+            extra_login_attributes: An optional dictionary of extra
+                attributes to be provided to the client in the login response.
 
         Raises:
             MappingException if there was a problem mapping the response to a user.
@@ -181,28 +236,45 @@ class SsoHandler:
         # interstitial pages.
         with await self._mapping_lock.queue(auth_provider_id):
             # first of all, check if we already have a mapping for this user
-            previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
+            user_id = await self.get_sso_user_by_remote_user_id(
                 auth_provider_id, remote_user_id,
             )
-            if previously_registered_user_id:
-                return previously_registered_user_id
 
             # Check for grandfathering of users.
-            if grandfather_existing_users:
-                previously_registered_user_id = await grandfather_existing_users()
-                if previously_registered_user_id:
+            if not user_id:
+                user_id = await grandfather_existing_users()
+                if user_id:
                     # Future logins should also match this user ID.
                     await self._store.record_user_external_id(
-                        auth_provider_id, remote_user_id, previously_registered_user_id
+                        auth_provider_id, remote_user_id, user_id
                     )
-                    return previously_registered_user_id
 
             # Otherwise, generate a new user.
-            attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
-            user_id = await self._register_mapped_user(
-                attributes, auth_provider_id, remote_user_id, user_agent, ip_address,
-            )
-            return user_id
+            if not user_id:
+                attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
+
+                if attributes.localpart is None:
+                    # the mapper doesn't return a username. bail out with a redirect to
+                    # the username picker.
+                    await self._redirect_to_username_picker(
+                        auth_provider_id,
+                        remote_user_id,
+                        attributes,
+                        client_redirect_url,
+                        extra_login_attributes,
+                    )
+
+                user_id = await self._register_mapped_user(
+                    attributes,
+                    auth_provider_id,
+                    remote_user_id,
+                    request.get_user_agent(""),
+                    request.getClientIP(),
+                )
+
+        await self._auth_handler.complete_sso_login(
+            user_id, request, client_redirect_url, extra_login_attributes
+        )
 
     async def _call_attribute_mapper(
         self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
@@ -229,10 +301,8 @@ class SsoHandler:
             )
 
             if not attributes.localpart:
-                raise MappingException(
-                    "Error parsing SSO response: SSO mapping provider plugin "
-                    "did not return a localpart value"
-                )
+                # the mapper has not picked a localpart
+                return attributes
 
             # Check if this mxid already exists
             user_id = UserID(attributes.localpart, self._server_name).to_string()
@@ -247,6 +317,59 @@ class SsoHandler:
             )
         return attributes
 
+    async def _redirect_to_username_picker(
+        self,
+        auth_provider_id: str,
+        remote_user_id: str,
+        attributes: UserAttributes,
+        client_redirect_url: str,
+        extra_login_attributes: Optional[JsonDict],
+    ) -> NoReturn:
+        """Creates a UsernameMappingSession and redirects the browser
+
+        Called if the user mapping provider doesn't return a localpart for a new user.
+        Raises a RedirectException which redirects the browser to the username picker.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+
+            remote_user_id: The unique identifier from the SSO provider.
+
+            attributes: the user attributes returned by the user mapping provider.
+
+            client_redirect_url: The redirect URL passed in by the client, which we
+                will eventually redirect back to.
+
+            extra_login_attributes: An optional dictionary of extra
+                attributes to be provided to the client in the login response.
+
+        Raises:
+            RedirectException
+        """
+        session_id = random_string(16)
+        now = self._clock.time_msec()
+        session = UsernameMappingSession(
+            auth_provider_id=auth_provider_id,
+            remote_user_id=remote_user_id,
+            display_name=attributes.display_name,
+            emails=attributes.emails,
+            client_redirect_url=client_redirect_url,
+            expiry_time_ms=now + self._MAPPING_SESSION_VALIDITY_PERIOD_MS,
+            extra_login_attributes=extra_login_attributes,
+        )
+
+        self._username_mapping_sessions[session_id] = session
+        logger.info("Recorded registration session id %s", session_id)
+
+        # Set the cookie and redirect to the username picker
+        e = RedirectException(b"/_synapse/client/pick_username")
+        e.cookies.append(
+            b"%s=%s; path=/"
+            % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii"))
+        )
+        raise e
+
     async def _register_mapped_user(
         self,
         attributes: UserAttributes,
@@ -255,9 +378,38 @@ class SsoHandler:
         user_agent: str,
         ip_address: str,
     ) -> str:
+        """Register a new SSO user.
+
+        This is called once we have successfully mapped the remote user id onto a local
+        user id, one way or another.
+
+        Args:
+             attributes: user attributes returned by the user mapping provider,
+                including a non-empty localpart.
+
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+
+            remote_user_id: The unique identifier from the SSO provider.
+
+            user_agent: The user-agent in the HTTP request (used for potential
+                shadow-banning.)
+
+            ip_address: The IP address of the requester (used for potential
+                shadow-banning.)
+
+        Raises:
+            a MappingException if the localpart is invalid.
+
+            a SynapseError with code 400 and errcode Codes.USER_IN_USE if the localpart
+            is already taken.
+        """
+
         # Since the localpart is provided via a potentially untrusted module,
         # ensure the MXID is valid before registering.
-        if contains_invalid_mxid_characters(attributes.localpart):
+        if not attributes.localpart or contains_invalid_mxid_characters(
+            attributes.localpart
+        ):
             raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
 
         logger.debug("Mapped SSO user to local part %s", attributes.localpart)
@@ -312,3 +464,108 @@ class SsoHandler:
         await self._auth_handler.complete_sso_ui_auth(
             user_id, ui_auth_session_id, request
         )
+
+    async def check_username_availability(
+        self, localpart: str, session_id: str,
+    ) -> bool:
+        """Handle an "is username available" callback check
+
+        Args:
+            localpart: desired localpart
+            session_id: the session id for the username picker
+        Returns:
+            True if the username is available
+        Raises:
+            SynapseError if the localpart is invalid or the session is unknown
+        """
+
+        # make sure that there is a valid mapping session, to stop people dictionary-
+        # scanning for accounts
+
+        self._expire_old_sessions()
+        session = self._username_mapping_sessions.get(session_id)
+        if not session:
+            logger.info("Couldn't find session id %s", session_id)
+            raise SynapseError(400, "unknown session")
+
+        logger.info(
+            "[session %s] Checking for availability of username %s",
+            session_id,
+            localpart,
+        )
+
+        if contains_invalid_mxid_characters(localpart):
+            raise SynapseError(400, "localpart is invalid: %s" % (localpart,))
+        user_id = UserID(localpart, self._server_name).to_string()
+        user_infos = await self._store.get_users_by_id_case_insensitive(user_id)
+
+        logger.info("[session %s] users: %s", session_id, user_infos)
+        return not user_infos
+
+    async def handle_submit_username_request(
+        self, request: SynapseRequest, localpart: str, session_id: str
+    ) -> None:
+        """Handle a request to the username-picker 'submit' endpoint
+
+        Will serve an HTTP response to the request.
+
+        Args:
+            request: HTTP request
+            localpart: localpart requested by the user
+            session_id: ID of the username mapping session, extracted from a cookie
+        """
+        self._expire_old_sessions()
+        session = self._username_mapping_sessions.get(session_id)
+        if not session:
+            logger.info("Couldn't find session id %s", session_id)
+            raise SynapseError(400, "unknown session")
+
+        logger.info("[session %s] Registering localpart %s", session_id, localpart)
+
+        attributes = UserAttributes(
+            localpart=localpart,
+            display_name=session.display_name,
+            emails=session.emails,
+        )
+
+        # the following will raise a 400 error if the username has been taken in the
+        # meantime.
+        user_id = await self._register_mapped_user(
+            attributes,
+            session.auth_provider_id,
+            session.remote_user_id,
+            request.get_user_agent(""),
+            request.getClientIP(),
+        )
+
+        logger.info("[session %s] Registered userid %s", session_id, user_id)
+
+        # delete the mapping session and the cookie
+        del self._username_mapping_sessions[session_id]
+
+        # delete the cookie
+        request.addCookie(
+            USERNAME_MAPPING_SESSION_COOKIE_NAME,
+            b"",
+            expires=b"Thu, 01 Jan 1970 00:00:00 GMT",
+            path=b"/",
+        )
+
+        await self._auth_handler.complete_sso_login(
+            user_id,
+            request,
+            session.client_redirect_url,
+            session.extra_login_attributes,
+        )
+
+    def _expire_old_sessions(self):
+        to_expire = []
+        now = int(self._clock.time_msec())
+
+        for session_id, session in self._username_mapping_sessions.items():
+            if session.expiry_time_ms <= now:
+                to_expire.append(session_id)
+
+        for session_id in to_expire:
+            logger.info("Expiring mapping session %s", session_id)
+            del self._username_mapping_sessions[session_id]