diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 4719712259..1a886cbbbf 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -34,10 +34,6 @@ from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util.msisdn import phone_number_to_msisdn
-import saml2
-from saml2.client import Saml2Client
-
-
logger = logging.getLogger(__name__)
@@ -378,28 +374,49 @@ class LoginRestServlet(RestServlet):
defer.returnValue(result)
-class CasRedirectServlet(RestServlet):
+class BaseSsoRedirectServlet(RestServlet):
+ """Common base class for /login/sso/redirect impls"""
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
+ def on_GET(self, request):
+ args = request.args
+ if b"redirectUrl" not in args:
+ return 400, "Redirect URL not specified for SSO auth"
+ client_redirect_url = args[b"redirectUrl"][0]
+ sso_url = self.get_sso_url(client_redirect_url)
+ request.redirect(sso_url)
+ finish_request(request)
+
+ def get_sso_url(self, client_redirect_url):
+ """Get the URL to redirect to, to perform SSO auth
+
+ Args:
+ client_redirect_url (bytes): the URL that we should redirect the
+ client to when everything is done
+
+ Returns:
+ bytes: URL to redirect to
+ """
+ # to be implemented by subclasses
+ raise NotImplementedError()
+
+
+class CasRedirectServlet(RestServlet):
def __init__(self, hs):
super(CasRedirectServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url.encode('ascii')
self.cas_service_url = hs.config.cas_service_url.encode('ascii')
- def on_GET(self, request):
- args = request.args
- if b"redirectUrl" not in args:
- return (400, "Redirect URL not specified for CAS auth")
+ def get_sso_url(self, client_redirect_url):
client_redirect_url_param = urllib.parse.urlencode({
- b"redirectUrl": args[b"redirectUrl"][0]
+ b"redirectUrl": client_redirect_url
}).encode('ascii')
hs_redirect_url = (self.cas_service_url +
b"/_matrix/client/r0/login/cas/ticket")
service_param = urllib.parse.urlencode({
b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
}).encode('ascii')
- request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param))
- finish_request(request)
+ return b"%s/login?%s" % (self.cas_server_url, service_param)
class CasTicketServlet(RestServlet):
@@ -482,41 +499,23 @@ class CasTicketServlet(RestServlet):
return user, attributes
-class SSORedirectServlet(RestServlet):
+class SAMLRedirectServlet(BaseSsoRedirectServlet):
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
def __init__(self, hs):
- super(SSORedirectServlet, self).__init__()
- self.saml2_sp_config = hs.config.saml2_sp_config
-
- def on_GET(self, request):
- args = request.args
-
- saml_client = Saml2Client(self.saml2_sp_config)
- reqid, info = saml_client.prepare_for_authenticate()
+ self._saml_client = hs.get_saml_client()
- redirect_url = None
+ def get_sso_url(self, client_redirect_url):
+ reqid, info = self._saml_client.prepare_for_authenticate(
+ relay_state=client_redirect_url,
+ )
- # Select the IdP URL to send the AuthN request to
for key, value in info['headers']:
- if key is 'Location':
- redirect_url = value
-
- if redirect_url is None:
- raise LoginError(401, "Unsuccessful SSO SAML2 redirect url response",
- errcode=Codes.UNAUTHORIZED)
-
- relay_state = "/_matrix/client/r0/login"
- if b"redirectUrl" in args:
- relay_state = args[b"redirectUrl"][0]
+ if key == 'Location':
+ return value
- url_parts = list(urllib.parse.urlparse(redirect_url))
- query = dict(urllib.parse.parse_qsl(url_parts[4]))
- query.update({"RelayState": relay_state})
- url_parts[4] = urllib.parse.urlencode(query)
-
- request.redirect(urllib.parse.urlunparse(url_parts))
- finish_request(request)
+ # this shouldn't happen!
+ raise Exception("prepare_for_authenticate didn't return a Location header")
class SSOAuthHandler(object):
@@ -594,5 +593,5 @@ def register_servlets(hs, http_server):
if hs.config.cas_enabled:
CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server)
- if hs.config.saml2_enabled:
- SSORedirectServlet(hs).register(http_server)
+ elif hs.config.saml2_enabled:
+ SAMLRedirectServlet(hs).register(http_server)
|