1 files changed, 25 insertions, 3 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 4de2f97d06..de7eca21f8 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -83,6 +83,7 @@ class LoginRestServlet(RestServlet):
self.jwt_algorithm = hs.config.jwt_algorithm
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
+ self.oidc_enabled = hs.config.oidc_enabled
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
@@ -96,9 +97,7 @@ class LoginRestServlet(RestServlet):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
- if self.saml2_enabled:
- flows.append({"type": LoginRestServlet.SSO_TYPE})
- flows.append({"type": LoginRestServlet.TOKEN_TYPE})
+
if self.cas_enabled:
flows.append({"type": LoginRestServlet.SSO_TYPE})
@@ -114,6 +113,11 @@ class LoginRestServlet(RestServlet):
# fall back to the fallback API if they don't understand one of the
# login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
+ elif self.saml2_enabled:
+ flows.append({"type": LoginRestServlet.SSO_TYPE})
+ flows.append({"type": LoginRestServlet.TOKEN_TYPE})
+ elif self.oidc_enabled:
+ flows.append({"type": LoginRestServlet.SSO_TYPE})
flows.extend(
({"type": t} for t in self.auth_handler.get_supported_login_types())
@@ -465,6 +469,22 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
return self._saml_handler.handle_redirect_request(client_redirect_url)
+class OIDCRedirectServlet(RestServlet):
+ """Implementation for /login/sso/redirect for the OIDC login flow."""
+
+ PATTERNS = client_patterns("/login/sso/redirect", v1=True)
+
+ def __init__(self, hs):
+ self._oidc_handler = hs.get_oidc_handler()
+
+ async def on_GET(self, request):
+ args = request.args
+ if b"redirectUrl" not in args:
+ return 400, "Redirect URL not specified for SSO auth"
+ client_redirect_url = args[b"redirectUrl"][0]
+ await self._oidc_handler.handle_redirect_request(request, client_redirect_url)
+
+
def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
if hs.config.cas_enabled:
@@ -472,3 +492,5 @@ def register_servlets(hs, http_server):
CasTicketServlet(hs).register(http_server)
elif hs.config.saml2_enabled:
SAMLRedirectServlet(hs).register(http_server)
+ elif hs.config.oidc_enabled:
+ OIDCRedirectServlet(hs).register(http_server)
|