summary refs log tree commit diff
path: root/synapse/rest/client/v1/login.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/v1/login.py')
-rw-r--r--synapse/rest/client/v1/login.py600
1 files changed, 0 insertions, 600 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
deleted file mode 100644
index 11567bf32c..0000000000
--- a/synapse/rest/client/v1/login.py
+++ /dev/null
@@ -1,600 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import re
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
-
-from typing_extensions import TypedDict
-
-from synapse.api.errors import Codes, LoginError, SynapseError
-from synapse.api.ratelimiting import Ratelimiter
-from synapse.api.urls import CLIENT_API_PREFIX
-from synapse.appservice import ApplicationService
-from synapse.handlers.sso import SsoIdentityProvider
-from synapse.http import get_request_uri
-from synapse.http.server import HttpServer, finish_request
-from synapse.http.servlet import (
-    RestServlet,
-    assert_params_in_dict,
-    parse_boolean,
-    parse_bytes_from_args,
-    parse_json_object_from_request,
-    parse_string,
-)
-from synapse.http.site import SynapseRequest
-from synapse.rest.client.v2_alpha._base import client_patterns
-from synapse.rest.well_known import WellKnownBuilder
-from synapse.types import JsonDict, UserID
-
-if TYPE_CHECKING:
-    from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-class LoginResponse(TypedDict, total=False):
-    user_id: str
-    access_token: str
-    home_server: str
-    expires_in_ms: Optional[int]
-    refresh_token: Optional[str]
-    device_id: str
-    well_known: Optional[Dict[str, Any]]
-
-
-class LoginRestServlet(RestServlet):
-    PATTERNS = client_patterns("/login$", v1=True)
-    CAS_TYPE = "m.login.cas"
-    SSO_TYPE = "m.login.sso"
-    TOKEN_TYPE = "m.login.token"
-    JWT_TYPE = "org.matrix.login.jwt"
-    JWT_TYPE_DEPRECATED = "m.login.jwt"
-    APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
-    REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token"
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.hs = hs
-
-        # JWT configuration variables.
-        self.jwt_enabled = hs.config.jwt_enabled
-        self.jwt_secret = hs.config.jwt_secret
-        self.jwt_algorithm = hs.config.jwt_algorithm
-        self.jwt_issuer = hs.config.jwt_issuer
-        self.jwt_audiences = hs.config.jwt_audiences
-
-        # SSO configuration.
-        self.saml2_enabled = hs.config.saml2_enabled
-        self.cas_enabled = hs.config.cas_enabled
-        self.oidc_enabled = hs.config.oidc_enabled
-        self._msc2858_enabled = hs.config.experimental.msc2858_enabled
-        self._msc2918_enabled = hs.config.access_token_lifetime is not None
-
-        self.auth = hs.get_auth()
-
-        self.clock = hs.get_clock()
-
-        self.auth_handler = self.hs.get_auth_handler()
-        self.registration_handler = hs.get_registration_handler()
-        self._sso_handler = hs.get_sso_handler()
-
-        self._well_known_builder = WellKnownBuilder(hs)
-        self._address_ratelimiter = Ratelimiter(
-            store=hs.get_datastore(),
-            clock=hs.get_clock(),
-            rate_hz=self.hs.config.rc_login_address.per_second,
-            burst_count=self.hs.config.rc_login_address.burst_count,
-        )
-        self._account_ratelimiter = Ratelimiter(
-            store=hs.get_datastore(),
-            clock=hs.get_clock(),
-            rate_hz=self.hs.config.rc_login_account.per_second,
-            burst_count=self.hs.config.rc_login_account.burst_count,
-        )
-
-    def on_GET(self, request: SynapseRequest):
-        flows = []
-        if self.jwt_enabled:
-            flows.append({"type": LoginRestServlet.JWT_TYPE})
-            flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED})
-
-        if self.cas_enabled:
-            # we advertise CAS for backwards compat, though MSC1721 renamed it
-            # to SSO.
-            flows.append({"type": LoginRestServlet.CAS_TYPE})
-
-        if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
-            sso_flow: JsonDict = {
-                "type": LoginRestServlet.SSO_TYPE,
-                "identity_providers": [
-                    _get_auth_flow_dict_for_idp(
-                        idp,
-                    )
-                    for idp in self._sso_handler.get_identity_providers().values()
-                ],
-            }
-
-            if self._msc2858_enabled:
-                # backwards-compatibility support for clients which don't
-                # support the stable API yet
-                sso_flow["org.matrix.msc2858.identity_providers"] = [
-                    _get_auth_flow_dict_for_idp(idp, use_unstable_brands=True)
-                    for idp in self._sso_handler.get_identity_providers().values()
-                ]
-
-            flows.append(sso_flow)
-
-            # While it's valid for us to advertise this login type generally,
-            # synapse currently only gives out these tokens as part of the
-            # SSO login flow.
-            # Generally we don't want to advertise login flows that clients
-            # don't know how to implement, since they (currently) will always
-            # fall back to the fallback API if they don't understand one of the
-            # login flow types returned.
-            flows.append({"type": LoginRestServlet.TOKEN_TYPE})
-
-        flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types())
-
-        flows.append({"type": LoginRestServlet.APPSERVICE_TYPE})
-
-        return 200, {"flows": flows}
-
-    async def on_POST(self, request: SynapseRequest):
-        login_submission = parse_json_object_from_request(request)
-
-        if self._msc2918_enabled:
-            # Check if this login should also issue a refresh token, as per
-            # MSC2918
-            should_issue_refresh_token = parse_boolean(
-                request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False
-            )
-        else:
-            should_issue_refresh_token = False
-
-        try:
-            if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
-                appservice = self.auth.get_appservice_by_req(request)
-
-                if appservice.is_rate_limited():
-                    await self._address_ratelimiter.ratelimit(
-                        None, request.getClientIP()
-                    )
-
-                result = await self._do_appservice_login(
-                    login_submission,
-                    appservice,
-                    should_issue_refresh_token=should_issue_refresh_token,
-                )
-            elif self.jwt_enabled and (
-                login_submission["type"] == LoginRestServlet.JWT_TYPE
-                or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
-            ):
-                await self._address_ratelimiter.ratelimit(None, request.getClientIP())
-                result = await self._do_jwt_login(
-                    login_submission,
-                    should_issue_refresh_token=should_issue_refresh_token,
-                )
-            elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
-                await self._address_ratelimiter.ratelimit(None, request.getClientIP())
-                result = await self._do_token_login(
-                    login_submission,
-                    should_issue_refresh_token=should_issue_refresh_token,
-                )
-            else:
-                await self._address_ratelimiter.ratelimit(None, request.getClientIP())
-                result = await self._do_other_login(
-                    login_submission,
-                    should_issue_refresh_token=should_issue_refresh_token,
-                )
-        except KeyError:
-            raise SynapseError(400, "Missing JSON keys.")
-
-        well_known_data = self._well_known_builder.get_well_known()
-        if well_known_data:
-            result["well_known"] = well_known_data
-        return 200, result
-
-    async def _do_appservice_login(
-        self,
-        login_submission: JsonDict,
-        appservice: ApplicationService,
-        should_issue_refresh_token: bool = False,
-    ):
-        identifier = login_submission.get("identifier")
-        logger.info("Got appservice login request with identifier: %r", identifier)
-
-        if not isinstance(identifier, dict):
-            raise SynapseError(
-                400, "Invalid identifier in login submission", Codes.INVALID_PARAM
-            )
-
-        # this login flow only supports identifiers of type "m.id.user".
-        if identifier.get("type") != "m.id.user":
-            raise SynapseError(
-                400, "Unknown login identifier type", Codes.INVALID_PARAM
-            )
-
-        user = identifier.get("user")
-        if not isinstance(user, str):
-            raise SynapseError(400, "Invalid user in identifier", Codes.INVALID_PARAM)
-
-        if user.startswith("@"):
-            qualified_user_id = user
-        else:
-            qualified_user_id = UserID(user, self.hs.hostname).to_string()
-
-        if not appservice.is_interested_in_user(qualified_user_id):
-            raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
-
-        return await self._complete_login(
-            qualified_user_id,
-            login_submission,
-            ratelimit=appservice.is_rate_limited(),
-            should_issue_refresh_token=should_issue_refresh_token,
-        )
-
-    async def _do_other_login(
-        self, login_submission: JsonDict, should_issue_refresh_token: bool = False
-    ) -> LoginResponse:
-        """Handle non-token/saml/jwt logins
-
-        Args:
-            login_submission:
-            should_issue_refresh_token: True if this login should issue
-                a refresh token alongside the access token.
-
-        Returns:
-            HTTP response
-        """
-        # Log the request we got, but only certain fields to minimise the chance of
-        # logging someone's password (even if they accidentally put it in the wrong
-        # field)
-        logger.info(
-            "Got login request with identifier: %r, medium: %r, address: %r, user: %r",
-            login_submission.get("identifier"),
-            login_submission.get("medium"),
-            login_submission.get("address"),
-            login_submission.get("user"),
-        )
-        canonical_user_id, callback = await self.auth_handler.validate_login(
-            login_submission, ratelimit=True
-        )
-        result = await self._complete_login(
-            canonical_user_id,
-            login_submission,
-            callback,
-            should_issue_refresh_token=should_issue_refresh_token,
-        )
-        return result
-
-    async def _complete_login(
-        self,
-        user_id: str,
-        login_submission: JsonDict,
-        callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None,
-        create_non_existent_users: bool = False,
-        ratelimit: bool = True,
-        auth_provider_id: Optional[str] = None,
-        should_issue_refresh_token: bool = False,
-    ) -> LoginResponse:
-        """Called when we've successfully authed the user and now need to
-        actually login them in (e.g. create devices). This gets called on
-        all successful logins.
-
-        Applies the ratelimiting for successful login attempts against an
-        account.
-
-        Args:
-            user_id: ID of the user to register.
-            login_submission: Dictionary of login information.
-            callback: Callback function to run after login.
-            create_non_existent_users: Whether to create the user if they don't
-                exist. Defaults to False.
-            ratelimit: Whether to ratelimit the login request.
-            auth_provider_id: The SSO IdP the user used, if any (just used for the
-                prometheus metrics).
-            should_issue_refresh_token: True if this login should issue
-                a refresh token alongside the access token.
-
-        Returns:
-            result: Dictionary of account information after successful login.
-        """
-
-        # Before we actually log them in we check if they've already logged in
-        # too often. This happens here rather than before as we don't
-        # necessarily know the user before now.
-        if ratelimit:
-            await self._account_ratelimiter.ratelimit(None, user_id.lower())
-
-        if create_non_existent_users:
-            canonical_uid = await self.auth_handler.check_user_exists(user_id)
-            if not canonical_uid:
-                canonical_uid = await self.registration_handler.register_user(
-                    localpart=UserID.from_string(user_id).localpart
-                )
-            user_id = canonical_uid
-
-        device_id = login_submission.get("device_id")
-        initial_display_name = login_submission.get("initial_device_display_name")
-        (
-            device_id,
-            access_token,
-            valid_until_ms,
-            refresh_token,
-        ) = await self.registration_handler.register_device(
-            user_id,
-            device_id,
-            initial_display_name,
-            auth_provider_id=auth_provider_id,
-            should_issue_refresh_token=should_issue_refresh_token,
-        )
-
-        result = LoginResponse(
-            user_id=user_id,
-            access_token=access_token,
-            home_server=self.hs.hostname,
-            device_id=device_id,
-        )
-
-        if valid_until_ms is not None:
-            expires_in_ms = valid_until_ms - self.clock.time_msec()
-            result["expires_in_ms"] = expires_in_ms
-
-        if refresh_token is not None:
-            result["refresh_token"] = refresh_token
-
-        if callback is not None:
-            await callback(result)
-
-        return result
-
-    async def _do_token_login(
-        self, login_submission: JsonDict, should_issue_refresh_token: bool = False
-    ) -> LoginResponse:
-        """
-        Handle the final stage of SSO login.
-
-        Args:
-            login_submission: The JSON request body.
-            should_issue_refresh_token: True if this login should issue
-                a refresh token alongside the access token.
-
-        Returns:
-            The body of the JSON response.
-        """
-        token = login_submission["token"]
-        auth_handler = self.auth_handler
-        res = await auth_handler.validate_short_term_login_token(token)
-
-        return await self._complete_login(
-            res.user_id,
-            login_submission,
-            self.auth_handler._sso_login_callback,
-            auth_provider_id=res.auth_provider_id,
-            should_issue_refresh_token=should_issue_refresh_token,
-        )
-
-    async def _do_jwt_login(
-        self, login_submission: JsonDict, should_issue_refresh_token: bool = False
-    ) -> LoginResponse:
-        token = login_submission.get("token", None)
-        if token is None:
-            raise LoginError(
-                403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
-            )
-
-        import jwt
-
-        try:
-            payload = jwt.decode(
-                token,
-                self.jwt_secret,
-                algorithms=[self.jwt_algorithm],
-                issuer=self.jwt_issuer,
-                audience=self.jwt_audiences,
-            )
-        except jwt.PyJWTError as e:
-            # A JWT error occurred, return some info back to the client.
-            raise LoginError(
-                403,
-                "JWT validation failed: %s" % (str(e),),
-                errcode=Codes.FORBIDDEN,
-            )
-
-        user = payload.get("sub", None)
-        if user is None:
-            raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
-
-        user_id = UserID(user, self.hs.hostname).to_string()
-        result = await self._complete_login(
-            user_id,
-            login_submission,
-            create_non_existent_users=True,
-            should_issue_refresh_token=should_issue_refresh_token,
-        )
-        return result
-
-
-def _get_auth_flow_dict_for_idp(
-    idp: SsoIdentityProvider, use_unstable_brands: bool = False
-) -> JsonDict:
-    """Return an entry for the login flow dict
-
-    Returns an entry suitable for inclusion in "identity_providers" in the
-    response to GET /_matrix/client/r0/login
-
-    Args:
-        idp: the identity provider to describe
-        use_unstable_brands: whether we should use brand identifiers suitable
-           for the unstable API
-    """
-    e: JsonDict = {"id": idp.idp_id, "name": idp.idp_name}
-    if idp.idp_icon:
-        e["icon"] = idp.idp_icon
-    if idp.idp_brand:
-        e["brand"] = idp.idp_brand
-    # use the stable brand identifier if the unstable identifier isn't defined.
-    if use_unstable_brands and idp.unstable_idp_brand:
-        e["brand"] = idp.unstable_idp_brand
-    return e
-
-
-class RefreshTokenServlet(RestServlet):
-    PATTERNS = client_patterns(
-        "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True
-    )
-
-    def __init__(self, hs: "HomeServer"):
-        self._auth_handler = hs.get_auth_handler()
-        self._clock = hs.get_clock()
-        self.access_token_lifetime = hs.config.access_token_lifetime
-
-    async def on_POST(
-        self,
-        request: SynapseRequest,
-    ):
-        refresh_submission = parse_json_object_from_request(request)
-
-        assert_params_in_dict(refresh_submission, ["refresh_token"])
-        token = refresh_submission["refresh_token"]
-        if not isinstance(token, str):
-            raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM)
-
-        valid_until_ms = self._clock.time_msec() + self.access_token_lifetime
-        access_token, refresh_token = await self._auth_handler.refresh_token(
-            token, valid_until_ms
-        )
-        expires_in_ms = valid_until_ms - self._clock.time_msec()
-        return (
-            200,
-            {
-                "access_token": access_token,
-                "refresh_token": refresh_token,
-                "expires_in_ms": expires_in_ms,
-            },
-        )
-
-
-class SsoRedirectServlet(RestServlet):
-    PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
-        re.compile(
-            "^"
-            + CLIENT_API_PREFIX
-            + "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
-        )
-    ]
-
-    def __init__(self, hs: "HomeServer"):
-        # make sure that the relevant handlers are instantiated, so that they
-        # register themselves with the main SSOHandler.
-        if hs.config.cas_enabled:
-            hs.get_cas_handler()
-        if hs.config.saml2_enabled:
-            hs.get_saml_handler()
-        if hs.config.oidc_enabled:
-            hs.get_oidc_handler()
-        self._sso_handler = hs.get_sso_handler()
-        self._msc2858_enabled = hs.config.experimental.msc2858_enabled
-        self._public_baseurl = hs.config.public_baseurl
-
-    def register(self, http_server: HttpServer) -> None:
-        super().register(http_server)
-        if self._msc2858_enabled:
-            # expose additional endpoint for MSC2858 support: backwards-compat support
-            # for clients which don't yet support the stable endpoints.
-            http_server.register_paths(
-                "GET",
-                client_patterns(
-                    "/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$",
-                    releases=(),
-                    unstable=True,
-                ),
-                self.on_GET,
-                self.__class__.__name__,
-            )
-
-    async def on_GET(
-        self, request: SynapseRequest, idp_id: Optional[str] = None
-    ) -> None:
-        if not self._public_baseurl:
-            raise SynapseError(400, "SSO requires a valid public_baseurl")
-
-        # if this isn't the expected hostname, redirect to the right one, so that we
-        # get our cookies back.
-        requested_uri = get_request_uri(request)
-        baseurl_bytes = self._public_baseurl.encode("utf-8")
-        if not requested_uri.startswith(baseurl_bytes):
-            # swap out the incorrect base URL for the right one.
-            #
-            # The idea here is to redirect from
-            #    https://foo.bar/whatever/_matrix/...
-            # to
-            #    https://public.baseurl/_matrix/...
-            #
-            i = requested_uri.index(b"/_matrix")
-            new_uri = baseurl_bytes[:-1] + requested_uri[i:]
-            logger.info(
-                "Requested URI %s is not canonical: redirecting to %s",
-                requested_uri.decode("utf-8", errors="replace"),
-                new_uri.decode("utf-8", errors="replace"),
-            )
-            request.redirect(new_uri)
-            finish_request(request)
-            return
-
-        args: Dict[bytes, List[bytes]] = request.args  # type: ignore
-        client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True)
-        sso_url = await self._sso_handler.handle_redirect_request(
-            request,
-            client_redirect_url,
-            idp_id,
-        )
-        logger.info("Redirecting to %s", sso_url)
-        request.redirect(sso_url)
-        finish_request(request)
-
-
-class CasTicketServlet(RestServlet):
-    PATTERNS = client_patterns("/login/cas/ticket", v1=True)
-
-    def __init__(self, hs):
-        super().__init__()
-        self._cas_handler = hs.get_cas_handler()
-
-    async def on_GET(self, request: SynapseRequest) -> None:
-        client_redirect_url = parse_string(request, "redirectUrl")
-        ticket = parse_string(request, "ticket", required=True)
-
-        # Maybe get a session ID (if this ticket is from user interactive
-        # authentication).
-        session = parse_string(request, "session")
-
-        # Either client_redirect_url or session must be provided.
-        if not client_redirect_url and not session:
-            message = "Missing string query parameter redirectUrl or session"
-            raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
-
-        await self._cas_handler.handle_ticket(
-            request, ticket, client_redirect_url, session
-        )
-
-
-def register_servlets(hs, http_server):
-    LoginRestServlet(hs).register(http_server)
-    if hs.config.access_token_lifetime is not None:
-        RefreshTokenServlet(hs).register(http_server)
-    SsoRedirectServlet(hs).register(http_server)
-    if hs.config.cas_enabled:
-        CasTicketServlet(hs).register(http_server)