diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index de7eca21f8..d89b2e5532 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -401,19 +401,22 @@ class BaseSSORedirectServlet(RestServlet):
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
- def on_GET(self, request: SynapseRequest):
+ async def on_GET(self, request: SynapseRequest):
args = request.args
if b"redirectUrl" not in args:
return 400, "Redirect URL not specified for SSO auth"
client_redirect_url = args[b"redirectUrl"][0]
- sso_url = self.get_sso_url(client_redirect_url)
+ sso_url = await self.get_sso_url(request, client_redirect_url)
request.redirect(sso_url)
finish_request(request)
- def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
"""Get the URL to redirect to, to perform SSO auth
Args:
+ request: The client request to redirect.
client_redirect_url: the URL that we should redirect the
client to when everything is done
@@ -428,7 +431,9 @@ class CasRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
self._cas_handler = hs.get_cas_handler()
- def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
return self._cas_handler.get_redirect_url(
{"redirectUrl": client_redirect_url}
).encode("ascii")
@@ -465,11 +470,13 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
self._saml_handler = hs.get_saml_handler()
- def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
return self._saml_handler.handle_redirect_request(client_redirect_url)
-class OIDCRedirectServlet(RestServlet):
+class OIDCRedirectServlet(BaseSSORedirectServlet):
"""Implementation for /login/sso/redirect for the OIDC login flow."""
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
@@ -477,12 +484,12 @@ class OIDCRedirectServlet(RestServlet):
def __init__(self, hs):
self._oidc_handler = hs.get_oidc_handler()
- async def on_GET(self, request):
- args = request.args
- if b"redirectUrl" not in args:
- return 400, "Redirect URL not specified for SSO auth"
- client_redirect_url = args[b"redirectUrl"][0]
- await self._oidc_handler.handle_redirect_request(request, client_redirect_url)
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
+ return await self._oidc_handler.handle_redirect_request(
+ request, client_redirect_url
+ )
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 24dd3d3e96..7bca1326d5 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -131,14 +131,19 @@ class AuthRestServlet(RestServlet):
self.registration_handler = hs.get_registration_handler()
# SSO configuration.
- self._saml_enabled = hs.config.saml2_enabled
- if self._saml_enabled:
- self._saml_handler = hs.get_saml_handler()
self._cas_enabled = hs.config.cas_enabled
if self._cas_enabled:
self._cas_handler = hs.get_cas_handler()
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
+ self._saml_enabled = hs.config.saml2_enabled
+ if self._saml_enabled:
+ self._saml_handler = hs.get_saml_handler()
+ self._oidc_enabled = hs.config.oidc_enabled
+ if self._oidc_enabled:
+ self._oidc_handler = hs.get_oidc_handler()
+ self._cas_server_url = hs.config.cas_server_url
+ self._cas_service_url = hs.config.cas_service_url
async def on_GET(self, request, stagetype):
session = parse_string(request, "session")
@@ -172,11 +177,17 @@ class AuthRestServlet(RestServlet):
)
elif self._saml_enabled:
- client_redirect_url = ""
+ client_redirect_url = b""
sso_redirect_url = self._saml_handler.handle_redirect_request(
client_redirect_url, session
)
+ elif self._oidc_enabled:
+ client_redirect_url = b""
+ sso_redirect_url = await self._oidc_handler.handle_redirect_request(
+ request, client_redirect_url, session
+ )
+
else:
raise SynapseError(400, "Homeserver not configured for SSO.")
|