summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/oidc_handler.py59
-rw-r--r--synapse/handlers/sso.py254
2 files changed, 273 insertions, 40 deletions
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index cbd11a1382..709f8dfc13 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -947,7 +947,7 @@ class OidcHandler(BaseHandler):
 
 
 UserAttributeDict = TypedDict(
-    "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
+    "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
 )
 C = TypeVar("C")
 
@@ -1028,10 +1028,10 @@ env = Environment(finalize=jinja_finalize)
 
 @attr.s
 class JinjaOidcMappingConfig:
-    subject_claim = attr.ib()  # type: str
-    localpart_template = attr.ib()  # type: Template
-    display_name_template = attr.ib()  # type: Optional[Template]
-    extra_attributes = attr.ib()  # type: Dict[str, Template]
+    subject_claim = attr.ib(type=str)
+    localpart_template = attr.ib(type=Optional[Template])
+    display_name_template = attr.ib(type=Optional[Template])
+    extra_attributes = attr.ib(type=Dict[str, Template])
 
 
 class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@@ -1047,18 +1047,14 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
     def parse_config(config: dict) -> JinjaOidcMappingConfig:
         subject_claim = config.get("subject_claim", "sub")
 
-        if "localpart_template" not in config:
-            raise ConfigError(
-                "missing key: oidc_config.user_mapping_provider.config.localpart_template"
-            )
-
-        try:
-            localpart_template = env.from_string(config["localpart_template"])
-        except Exception as e:
-            raise ConfigError(
-                "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r"
-                % (e,)
-            )
+        localpart_template = None  # type: Optional[Template]
+        if "localpart_template" in config:
+            try:
+                localpart_template = env.from_string(config["localpart_template"])
+            except Exception as e:
+                raise ConfigError(
+                    "invalid jinja template", path=["localpart_template"]
+                ) from e
 
         display_name_template = None  # type: Optional[Template]
         if "display_name_template" in config:
@@ -1066,26 +1062,22 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
                 display_name_template = env.from_string(config["display_name_template"])
             except Exception as e:
                 raise ConfigError(
-                    "invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r"
-                    % (e,)
-                )
+                    "invalid jinja template", path=["display_name_template"]
+                ) from e
 
         extra_attributes = {}  # type Dict[str, Template]
         if "extra_attributes" in config:
             extra_attributes_config = config.get("extra_attributes") or {}
             if not isinstance(extra_attributes_config, dict):
-                raise ConfigError(
-                    "oidc_config.user_mapping_provider.config.extra_attributes must be a dict"
-                )
+                raise ConfigError("must be a dict", path=["extra_attributes"])
 
             for key, value in extra_attributes_config.items():
                 try:
                     extra_attributes[key] = env.from_string(value)
                 except Exception as e:
                     raise ConfigError(
-                        "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r"
-                        % (key, e)
-                    )
+                        "invalid jinja template", path=["extra_attributes", key]
+                    ) from e
 
         return JinjaOidcMappingConfig(
             subject_claim=subject_claim,
@@ -1100,14 +1092,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
     async def map_user_attributes(
         self, userinfo: UserInfo, token: Token, failures: int
     ) -> UserAttributeDict:
-        localpart = self._config.localpart_template.render(user=userinfo).strip()
+        localpart = None
+
+        if self._config.localpart_template:
+            localpart = self._config.localpart_template.render(user=userinfo).strip()
 
-        # Ensure only valid characters are included in the MXID.
-        localpart = map_username_to_mxid_localpart(localpart)
+            # Ensure only valid characters are included in the MXID.
+            localpart = map_username_to_mxid_localpart(localpart)
 
-        # Append suffix integer if last call to this function failed to produce
-        # a usable mxid.
-        localpart += str(failures) if failures else ""
+            # Append suffix integer if last call to this function failed to produce
+            # a usable mxid.
+            localpart += str(failures) if failures else ""
 
         display_name = None  # type: Optional[str]
         if self._config.display_name_template is not None:
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index f054b66a53..548b02211b 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -13,17 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
 
 import attr
+from typing_extensions import NoReturn
 
 from twisted.web.http import Request
 
-from synapse.api.errors import RedirectException
+from synapse.api.errors import RedirectException, SynapseError
 from synapse.http.server import respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
 from synapse.util.async_helpers import Linearizer
+from synapse.util.stringutils import random_string
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -40,16 +42,52 @@ class MappingException(Exception):
 
 @attr.s
 class UserAttributes:
-    localpart = attr.ib(type=str)
+    # the localpart of the mxid that the mapper has assigned to the user.
+    # if `None`, the mapper has not picked a userid, and the user should be prompted to
+    # enter one.
+    localpart = attr.ib(type=Optional[str])
     display_name = attr.ib(type=Optional[str], default=None)
     emails = attr.ib(type=List[str], default=attr.Factory(list))
 
 
+@attr.s(slots=True)
+class UsernameMappingSession:
+    """Data we track about SSO sessions"""
+
+    # A unique identifier for this SSO provider, e.g.  "oidc" or "saml".
+    auth_provider_id = attr.ib(type=str)
+
+    # user ID on the IdP server
+    remote_user_id = attr.ib(type=str)
+
+    # attributes returned by the ID mapper
+    display_name = attr.ib(type=Optional[str])
+    emails = attr.ib(type=List[str])
+
+    # An optional dictionary of extra attributes to be provided to the client in the
+    # login response.
+    extra_login_attributes = attr.ib(type=Optional[JsonDict])
+
+    # where to redirect the client back to
+    client_redirect_url = attr.ib(type=str)
+
+    # expiry time for the session, in milliseconds
+    expiry_time_ms = attr.ib(type=int)
+
+
+# the HTTP cookie used to track the mapping session id
+USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
+
+
 class SsoHandler:
     # The number of attempts to ask the mapping provider for when generating an MXID.
     _MAP_USERNAME_RETRIES = 1000
 
+    # the time a UsernameMappingSession remains valid for
+    _MAPPING_SESSION_VALIDITY_PERIOD_MS = 15 * 60 * 1000
+
     def __init__(self, hs: "HomeServer"):
+        self._clock = hs.get_clock()
         self._store = hs.get_datastore()
         self._server_name = hs.hostname
         self._registration_handler = hs.get_registration_handler()
@@ -59,6 +97,9 @@ class SsoHandler:
         # a lock on the mappings
         self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
 
+        # a map from session id to session data
+        self._username_mapping_sessions = {}  # type: Dict[str, UsernameMappingSession]
+
     def render_error(
         self, request, error: str, error_description: Optional[str] = None
     ) -> None:
@@ -206,6 +247,18 @@ class SsoHandler:
             # Otherwise, generate a new user.
             if not user_id:
                 attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
+
+                if attributes.localpart is None:
+                    # the mapper doesn't return a username. bail out with a redirect to
+                    # the username picker.
+                    await self._redirect_to_username_picker(
+                        auth_provider_id,
+                        remote_user_id,
+                        attributes,
+                        client_redirect_url,
+                        extra_login_attributes,
+                    )
+
                 user_id = await self._register_mapped_user(
                     attributes,
                     auth_provider_id,
@@ -243,10 +296,8 @@ class SsoHandler:
             )
 
             if not attributes.localpart:
-                raise MappingException(
-                    "Error parsing SSO response: SSO mapping provider plugin "
-                    "did not return a localpart value"
-                )
+                # the mapper has not picked a localpart
+                return attributes
 
             # Check if this mxid already exists
             user_id = UserID(attributes.localpart, self._server_name).to_string()
@@ -261,6 +312,59 @@ class SsoHandler:
             )
         return attributes
 
+    async def _redirect_to_username_picker(
+        self,
+        auth_provider_id: str,
+        remote_user_id: str,
+        attributes: UserAttributes,
+        client_redirect_url: str,
+        extra_login_attributes: Optional[JsonDict],
+    ) -> NoReturn:
+        """Creates a UsernameMappingSession and redirects the browser
+
+        Called if the user mapping provider doesn't return a localpart for a new user.
+        Raises a RedirectException which redirects the browser to the username picker.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+
+            remote_user_id: The unique identifier from the SSO provider.
+
+            attributes: the user attributes returned by the user mapping provider.
+
+            client_redirect_url: The redirect URL passed in by the client, which we
+                will eventually redirect back to.
+
+            extra_login_attributes: An optional dictionary of extra
+                attributes to be provided to the client in the login response.
+
+        Raises:
+            RedirectException
+        """
+        session_id = random_string(16)
+        now = self._clock.time_msec()
+        session = UsernameMappingSession(
+            auth_provider_id=auth_provider_id,
+            remote_user_id=remote_user_id,
+            display_name=attributes.display_name,
+            emails=attributes.emails,
+            client_redirect_url=client_redirect_url,
+            expiry_time_ms=now + self._MAPPING_SESSION_VALIDITY_PERIOD_MS,
+            extra_login_attributes=extra_login_attributes,
+        )
+
+        self._username_mapping_sessions[session_id] = session
+        logger.info("Recorded registration session id %s", session_id)
+
+        # Set the cookie and redirect to the username picker
+        e = RedirectException(b"/_synapse/client/pick_username")
+        e.cookies.append(
+            b"%s=%s; path=/"
+            % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii"))
+        )
+        raise e
+
     async def _register_mapped_user(
         self,
         attributes: UserAttributes,
@@ -269,9 +373,38 @@ class SsoHandler:
         user_agent: str,
         ip_address: str,
     ) -> str:
+        """Register a new SSO user.
+
+        This is called once we have successfully mapped the remote user id onto a local
+        user id, one way or another.
+
+        Args:
+             attributes: user attributes returned by the user mapping provider,
+                including a non-empty localpart.
+
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+
+            remote_user_id: The unique identifier from the SSO provider.
+
+            user_agent: The user-agent in the HTTP request (used for potential
+                shadow-banning.)
+
+            ip_address: The IP address of the requester (used for potential
+                shadow-banning.)
+
+        Raises:
+            a MappingException if the localpart is invalid.
+
+            a SynapseError with code 400 and errcode Codes.USER_IN_USE if the localpart
+            is already taken.
+        """
+
         # Since the localpart is provided via a potentially untrusted module,
         # ensure the MXID is valid before registering.
-        if contains_invalid_mxid_characters(attributes.localpart):
+        if not attributes.localpart or contains_invalid_mxid_characters(
+            attributes.localpart
+        ):
             raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
 
         logger.debug("Mapped SSO user to local part %s", attributes.localpart)
@@ -326,3 +459,108 @@ class SsoHandler:
         await self._auth_handler.complete_sso_ui_auth(
             user_id, ui_auth_session_id, request
         )
+
+    async def check_username_availability(
+        self, localpart: str, session_id: str,
+    ) -> bool:
+        """Handle an "is username available" callback check
+
+        Args:
+            localpart: desired localpart
+            session_id: the session id for the username picker
+        Returns:
+            True if the username is available
+        Raises:
+            SynapseError if the localpart is invalid or the session is unknown
+        """
+
+        # make sure that there is a valid mapping session, to stop people dictionary-
+        # scanning for accounts
+
+        self._expire_old_sessions()
+        session = self._username_mapping_sessions.get(session_id)
+        if not session:
+            logger.info("Couldn't find session id %s", session_id)
+            raise SynapseError(400, "unknown session")
+
+        logger.info(
+            "[session %s] Checking for availability of username %s",
+            session_id,
+            localpart,
+        )
+
+        if contains_invalid_mxid_characters(localpart):
+            raise SynapseError(400, "localpart is invalid: %s" % (localpart,))
+        user_id = UserID(localpart, self._server_name).to_string()
+        user_infos = await self._store.get_users_by_id_case_insensitive(user_id)
+
+        logger.info("[session %s] users: %s", session_id, user_infos)
+        return not user_infos
+
+    async def handle_submit_username_request(
+        self, request: SynapseRequest, localpart: str, session_id: str
+    ) -> None:
+        """Handle a request to the username-picker 'submit' endpoint
+
+        Will serve an HTTP response to the request.
+
+        Args:
+            request: HTTP request
+            localpart: localpart requested by the user
+            session_id: ID of the username mapping session, extracted from a cookie
+        """
+        self._expire_old_sessions()
+        session = self._username_mapping_sessions.get(session_id)
+        if not session:
+            logger.info("Couldn't find session id %s", session_id)
+            raise SynapseError(400, "unknown session")
+
+        logger.info("[session %s] Registering localpart %s", session_id, localpart)
+
+        attributes = UserAttributes(
+            localpart=localpart,
+            display_name=session.display_name,
+            emails=session.emails,
+        )
+
+        # the following will raise a 400 error if the username has been taken in the
+        # meantime.
+        user_id = await self._register_mapped_user(
+            attributes,
+            session.auth_provider_id,
+            session.remote_user_id,
+            request.get_user_agent(""),
+            request.getClientIP(),
+        )
+
+        logger.info("[session %s] Registered userid %s", session_id, user_id)
+
+        # delete the mapping session and the cookie
+        del self._username_mapping_sessions[session_id]
+
+        # delete the cookie
+        request.addCookie(
+            USERNAME_MAPPING_SESSION_COOKIE_NAME,
+            b"",
+            expires=b"Thu, 01 Jan 1970 00:00:00 GMT",
+            path=b"/",
+        )
+
+        await self._auth_handler.complete_sso_login(
+            user_id,
+            request,
+            session.client_redirect_url,
+            session.extra_login_attributes,
+        )
+
+    def _expire_old_sessions(self):
+        to_expire = []
+        now = int(self._clock.time_msec())
+
+        for session_id, session in self._username_mapping_sessions.items():
+            if session.expiry_time_ms <= now:
+                to_expire.append(session_id)
+
+        for session_id in to_expire:
+            logger.info("Expiring mapping session %s", session_id)
+            del self._username_mapping_sessions[session_id]