diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 029039c162..ae9bbba619 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -33,6 +33,9 @@ from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util.msisdn import phone_number_to_msisdn
+import saml2
+from saml2.client import Saml2Client
+
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__)
@@ -93,6 +96,7 @@ class LoginRestServlet(ClientV1RestServlet):
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm
+ self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
@@ -104,6 +108,9 @@ class LoginRestServlet(ClientV1RestServlet):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
+ if self.saml2_enabled:
+ flows.append({"type": LoginRestServlet.SSO_TYPE})
+ flows.append({"type": LoginRestServlet.TOKEN_TYPE})
if self.cas_enabled:
flows.append({"type": LoginRestServlet.SSO_TYPE})
@@ -474,6 +481,43 @@ class CasTicketServlet(ClientV1RestServlet):
return user, attributes
+class SSORedirectServlet(RestServlet):
+ PATTERNS = client_path_patterns("/login/sso/redirect")
+
+ def __init__(self, hs):
+ super(SSORedirectServlet, self).__init__()
+ self.saml2_sp_config = hs.config.saml2_sp_config
+
+ def on_GET(self, request):
+ args = request.args
+
+ saml_client = Saml2Client(self.saml2_sp_config)
+ reqid, info = saml_client.prepare_for_authenticate()
+
+ redirect_url = None
+
+ # Select the IdP URL to send the AuthN request to
+ for key, value in info['headers']:
+ if key is 'Location':
+ redirect_url = value
+
+ if redirect_url is None:
+ raise LoginError(401, "Unsuccessful SSO SAML2 redirect url response",
+ errcode=Codes.UNAUTHORIZED)
+
+ relay_state = "/_matrix/client/r0/login"
+ if b"redirectUrl" in args:
+ relay_state = args[b"redirectUrl"][0]
+
+ url_parts = list(urllib.parse.urlparse(redirect_url))
+ query = dict(urllib.parse.parse_qsl(url_parts[4]))
+ query.update({"RelayState": relay_state})
+ url_parts[4] = urllib.parse.urlencode(query)
+
+ request.redirect(urllib.parse.urlunparse(url_parts))
+ finish_request(request)
+
+
class SSOAuthHandler(object):
"""
Utility class for Resources and Servlets which handle the response from a SSO
@@ -549,3 +593,5 @@ def register_servlets(hs, http_server):
if hs.config.cas_enabled:
CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server)
+ if hs.config.saml2_enabled:
+ SSORedirectServlet(hs).register(http_server)
|