diff options
Diffstat (limited to 'synapse/rest')
-rw-r--r-- | synapse/rest/client/v1/login.py | 54 | ||||
-rw-r--r-- | synapse/rest/saml2/response_resource.py | 37 |
2 files changed, 48 insertions, 43 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index ede6bc8b1e..3051b2171b 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -86,6 +86,7 @@ class LoginRestServlet(RestServlet): self.jwt_enabled = hs.config.jwt_enabled self.jwt_secret = hs.config.jwt_secret self.jwt_algorithm = hs.config.jwt_algorithm + self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() @@ -97,6 +98,9 @@ class LoginRestServlet(RestServlet): flows = [] if self.jwt_enabled: flows.append({"type": LoginRestServlet.JWT_TYPE}) + if self.saml2_enabled: + flows.append({"type": LoginRestServlet.SSO_TYPE}) + flows.append({"type": LoginRestServlet.TOKEN_TYPE}) if self.cas_enabled: flows.append({"type": LoginRestServlet.SSO_TYPE}) @@ -354,27 +358,49 @@ class LoginRestServlet(RestServlet): defer.returnValue(result) -class CasRedirectServlet(RestServlet): +class BaseSSORedirectServlet(RestServlet): + """Common base class for /login/sso/redirect impls""" + PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) + def on_GET(self, request): + args = request.args + if b"redirectUrl" not in args: + return 400, "Redirect URL not specified for SSO auth" + client_redirect_url = args[b"redirectUrl"][0] + sso_url = self.get_sso_url(client_redirect_url) + request.redirect(sso_url) + finish_request(request) + + def get_sso_url(self, client_redirect_url): + """Get the URL to redirect to, to perform SSO auth + + Args: + client_redirect_url (bytes): the URL that we should redirect the + client to when everything is done + + Returns: + bytes: URL to redirect to + """ + # to be implemented by subclasses + raise NotImplementedError() + + +class CasRedirectServlet(BaseSSORedirectServlet): def __init__(self, hs): super(CasRedirectServlet, self).__init__() self.cas_server_url = hs.config.cas_server_url.encode("ascii") self.cas_service_url = hs.config.cas_service_url.encode("ascii") - def on_GET(self, request): - args = request.args - if b"redirectUrl" not in args: - return (400, "Redirect URL not specified for CAS auth") + def get_sso_url(self, client_redirect_url): client_redirect_url_param = urllib.parse.urlencode( - {b"redirectUrl": args[b"redirectUrl"][0]} + {b"redirectUrl": client_redirect_url} ).encode("ascii") hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket" service_param = urllib.parse.urlencode( {b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)} ).encode("ascii") - request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param)) - finish_request(request) + return b"%s/login?%s" % (self.cas_server_url, service_param) class CasTicketServlet(RestServlet): @@ -457,6 +483,16 @@ class CasTicketServlet(RestServlet): return user, attributes +class SAMLRedirectServlet(BaseSSORedirectServlet): + PATTERNS = client_patterns("/login/sso/redirect", v1=True) + + def __init__(self, hs): + self._saml_handler = hs.get_saml_handler() + + def get_sso_url(self, client_redirect_url): + return self._saml_handler.handle_redirect_request(client_redirect_url) + + class SSOAuthHandler(object): """ Utility class for Resources and Servlets which handle the response from a SSO @@ -532,3 +568,5 @@ def register_servlets(hs, http_server): if hs.config.cas_enabled: CasRedirectServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server) + elif hs.config.saml2_enabled: + SAMLRedirectServlet(hs).register(http_server) diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py index 939c87306c..69ecc5e4b4 100644 --- a/synapse/rest/saml2/response_resource.py +++ b/synapse/rest/saml2/response_resource.py @@ -13,17 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -import saml2 -from saml2.client import Saml2Client - -from synapse.api.errors import CodeMessageException from synapse.http.server import DirectServeResource, wrap_html_request_handler -from synapse.http.servlet import parse_string -from synapse.rest.client.v1.login import SSOAuthHandler - -logger = logging.getLogger(__name__) class SAML2ResponseResource(DirectServeResource): @@ -33,32 +24,8 @@ class SAML2ResponseResource(DirectServeResource): def __init__(self, hs): super().__init__() - - self._saml_client = Saml2Client(hs.config.saml2_sp_config) - self._sso_auth_handler = SSOAuthHandler(hs) + self._saml_handler = hs.get_saml_handler() @wrap_html_request_handler async def _async_render_POST(self, request): - resp_bytes = parse_string(request, "SAMLResponse", required=True) - relay_state = parse_string(request, "RelayState", required=True) - - try: - saml2_auth = self._saml_client.parse_authn_request_response( - resp_bytes, saml2.BINDING_HTTP_POST - ) - except Exception as e: - logger.warning("Exception parsing SAML2 response", exc_info=1) - raise CodeMessageException(400, "Unable to parse SAML2 response: %s" % (e,)) - - if saml2_auth.not_signed: - raise CodeMessageException(400, "SAML2 response was not signed") - - if "uid" not in saml2_auth.ava: - raise CodeMessageException(400, "uid not in SAML2 response") - - username = saml2_auth.ava["uid"][0] - - displayName = saml2_auth.ava.get("displayName", [None])[0] - return self._sso_auth_handler.on_successful_auth( - username, request, relay_state, user_display_name=displayName - ) + return await self._saml_handler.handle_saml_response(request) |