diff options
Diffstat (limited to 'synapse/rest/client/v1')
-rw-r--r-- | synapse/rest/client/v1/login.py | 89 |
1 files changed, 20 insertions, 69 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 5f4c6703db..ebc346105b 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -311,48 +311,31 @@ class LoginRestServlet(RestServlet): return result -class BaseSSORedirectServlet(RestServlet): - """Common base class for /login/sso/redirect impls""" - +class SsoRedirectServlet(RestServlet): PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) + def __init__(self, hs: "HomeServer"): + # make sure that the relevant handlers are instantiated, so that they + # register themselves with the main SSOHandler. + if hs.config.cas_enabled: + hs.get_cas_handler() + elif hs.config.saml2_enabled: + hs.get_saml_handler() + elif hs.config.oidc_enabled: + hs.get_oidc_handler() + self._sso_handler = hs.get_sso_handler() + async def on_GET(self, request: SynapseRequest): - args = request.args - if b"redirectUrl" not in args: - return 400, "Redirect URL not specified for SSO auth" - client_redirect_url = args[b"redirectUrl"][0] - sso_url = await self.get_sso_url(request, client_redirect_url) + client_redirect_url = parse_string( + request, "redirectUrl", required=True, encoding=None + ) + sso_url = await self._sso_handler.handle_redirect_request( + request, client_redirect_url + ) + logger.info("Redirecting to %s", sso_url) request.redirect(sso_url) finish_request(request) - async def get_sso_url( - self, request: SynapseRequest, client_redirect_url: bytes - ) -> bytes: - """Get the URL to redirect to, to perform SSO auth - - Args: - request: The client request to redirect. - client_redirect_url: the URL that we should redirect the - client to when everything is done - - Returns: - URL to redirect to - """ - # to be implemented by subclasses - raise NotImplementedError() - - -class CasRedirectServlet(BaseSSORedirectServlet): - def __init__(self, hs): - self._cas_handler = hs.get_cas_handler() - - async def get_sso_url( - self, request: SynapseRequest, client_redirect_url: bytes - ) -> bytes: - return self._cas_handler.get_redirect_url( - {"redirectUrl": client_redirect_url} - ).encode("ascii") - class CasTicketServlet(RestServlet): PATTERNS = client_patterns("/login/cas/ticket", v1=True) @@ -379,40 +362,8 @@ class CasTicketServlet(RestServlet): ) -class SAMLRedirectServlet(BaseSSORedirectServlet): - PATTERNS = client_patterns("/login/sso/redirect", v1=True) - - def __init__(self, hs): - self._saml_handler = hs.get_saml_handler() - - async def get_sso_url( - self, request: SynapseRequest, client_redirect_url: bytes - ) -> bytes: - return self._saml_handler.handle_redirect_request(client_redirect_url) - - -class OIDCRedirectServlet(BaseSSORedirectServlet): - """Implementation for /login/sso/redirect for the OIDC login flow.""" - - PATTERNS = client_patterns("/login/sso/redirect", v1=True) - - def __init__(self, hs): - self._oidc_handler = hs.get_oidc_handler() - - async def get_sso_url( - self, request: SynapseRequest, client_redirect_url: bytes - ) -> bytes: - return await self._oidc_handler.handle_redirect_request( - request, client_redirect_url - ) - - def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) + SsoRedirectServlet(hs).register(http_server) if hs.config.cas_enabled: - CasRedirectServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server) - elif hs.config.saml2_enabled: - SAMLRedirectServlet(hs).register(http_server) - elif hs.config.oidc_enabled: - OIDCRedirectServlet(hs).register(http_server) |