summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2021-01-04 18:13:49 +0000
committerGitHub <noreply@github.com>2021-01-04 18:13:49 +0000
commitd2c616a41381c9e2d43b08d5f225b52042d94d23 (patch)
tree1ae1b9beda79b718f67d8eaa3067de9370403d2b
parentAdd type hints to the receipts and user directory handlers. (#8976) (diff)
downloadsynapse-d2c616a41381c9e2d43b08d5f225b52042d94d23.tar.xz
Combine the SSO Redirect Servlets (#9015)
* Implement CasHandler.handle_redirect_request

... to make it match OidcHandler and SamlHandler

* Clean up interface for OidcHandler.handle_redirect_request

Make it accept `client_redirect_url=None`.

* Clean up interface for `SamlHandler.handle_redirect_request`

... bring it into line with CAS and OIDC by making it take a Request parameter,
move the magic for `client_redirect_url` for UIA into the handler, and fix the
return type to be a `str` rather than a `bytes`.

* Define a common protocol for SSO auth provider impls

* Give SsoIdentityProvider an ID and register them

* Combine the SSO Redirect servlets

Now that the SsoHandler knows about the identity providers, we can combine the
various *RedirectServlets into a single implementation which delegates to the
right IdP.

* changelog
-rw-r--r--changelog.d/9015.feature1
-rw-r--r--synapse/handlers/cas_handler.py35
-rw-r--r--synapse/handlers/oidc_handler.py15
-rw-r--r--synapse/handlers/saml_handler.py25
-rw-r--r--synapse/handlers/sso.py86
-rw-r--r--synapse/rest/client/v1/login.py89
-rw-r--r--synapse/rest/client/v2_alpha/auth.py34
-rw-r--r--tests/rest/client/v1/test_login.py2
8 files changed, 174 insertions, 113 deletions
diff --git a/changelog.d/9015.feature b/changelog.d/9015.feature
new file mode 100644
index 0000000000..01a24dcf49
--- /dev/null
+++ b/changelog.d/9015.feature
@@ -0,0 +1 @@
+Add support for multiple SSO Identity Providers.
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index fca210a5a6..295974c521 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -75,10 +75,12 @@ class CasHandler:
         self._http_client = hs.get_proxied_http_client()
 
         # identifier for the external_ids table
-        self._auth_provider_id = "cas"
+        self.idp_id = "cas"
 
         self._sso_handler = hs.get_sso_handler()
 
+        self._sso_handler.register_identity_provider(self)
+
     def _build_service_param(self, args: Dict[str, str]) -> str:
         """
         Generates a value to use as the "service" parameter when redirecting or
@@ -105,7 +107,7 @@ class CasHandler:
         Args:
             ticket: The CAS ticket from the client.
             service_args: Additional arguments to include in the service URL.
-                Should be the same as those passed to `get_redirect_url`.
+                Should be the same as those passed to `handle_redirect_request`.
 
         Raises:
             CasError: If there's an error parsing the CAS response.
@@ -184,16 +186,31 @@ class CasHandler:
 
         return CasResponse(user, attributes)
 
-    def get_redirect_url(self, service_args: Dict[str, str]) -> str:
-        """
-        Generates a URL for the CAS server where the client should be redirected.
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: Optional[bytes],
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
+        """Generates a URL for the CAS server where the client should be redirected.
 
         Args:
-            service_args: Additional arguments to include in the final redirect URL.
+            request: the incoming HTTP request
+            client_redirect_url: the URL that we should redirect the
+                client to after login (or None for UI Auth).
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
 
         Returns:
-            The URL to redirect the client to.
+            URL to redirect to
         """
+
+        if ui_auth_session_id:
+            service_args = {"session": ui_auth_session_id}
+        else:
+            assert client_redirect_url
+            service_args = {"redirectUrl": client_redirect_url.decode("utf8")}
+
         args = urllib.parse.urlencode(
             {"service": self._build_service_param(service_args)}
         )
@@ -275,7 +292,7 @@ class CasHandler:
         # first check if we're doing a UIA
         if session:
             return await self._sso_handler.complete_sso_ui_auth_request(
-                self._auth_provider_id, cas_response.username, session, request,
+                self.idp_id, cas_response.username, session, request,
             )
 
         # otherwise, we're handling a login request.
@@ -375,7 +392,7 @@ class CasHandler:
             return None
 
         await self._sso_handler.complete_sso_login_request(
-            self._auth_provider_id,
+            self.idp_id,
             cas_response.username,
             request,
             client_redirect_url,
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 709f8dfc13..3e2b60eb7b 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -119,10 +119,12 @@ class OidcHandler(BaseHandler):
         self._macaroon_secret_key = hs.config.macaroon_secret_key
 
         # identifier for the external_ids table
-        self._auth_provider_id = "oidc"
+        self.idp_id = "oidc"
 
         self._sso_handler = hs.get_sso_handler()
 
+        self._sso_handler.register_identity_provider(self)
+
     def _validate_metadata(self):
         """Verifies the provider metadata.
 
@@ -475,7 +477,7 @@ class OidcHandler(BaseHandler):
     async def handle_redirect_request(
         self,
         request: SynapseRequest,
-        client_redirect_url: bytes,
+        client_redirect_url: Optional[bytes],
         ui_auth_session_id: Optional[str] = None,
     ) -> str:
         """Handle an incoming request to /login/sso/redirect
@@ -499,7 +501,7 @@ class OidcHandler(BaseHandler):
             request: the incoming request from the browser.
                 We'll respond to it with a redirect and a cookie.
             client_redirect_url: the URL that we should redirect the client to
-                when everything is done
+                when everything is done (or None for UI Auth)
             ui_auth_session_id: The session ID of the ongoing UI Auth (or
                 None if this is a login).
 
@@ -511,6 +513,9 @@ class OidcHandler(BaseHandler):
         state = generate_token()
         nonce = generate_token()
 
+        if not client_redirect_url:
+            client_redirect_url = b""
+
         cookie = self._generate_oidc_session_token(
             state=state,
             nonce=nonce,
@@ -682,7 +687,7 @@ class OidcHandler(BaseHandler):
                 return
 
             return await self._sso_handler.complete_sso_ui_auth_request(
-                self._auth_provider_id, remote_user_id, ui_auth_session_id, request
+                self.idp_id, remote_user_id, ui_auth_session_id, request
             )
 
         # otherwise, it's a login
@@ -923,7 +928,7 @@ class OidcHandler(BaseHandler):
             extra_attributes = await get_extra_attributes(userinfo, token)
 
         await self._sso_handler.complete_sso_login_request(
-            self._auth_provider_id,
+            self.idp_id,
             remote_user_id,
             request,
             client_redirect_url,
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 5fa7ab3f8b..6106237f1f 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -73,27 +73,38 @@ class SamlHandler(BaseHandler):
         )
 
         # identifier for the external_ids table
-        self._auth_provider_id = "saml"
+        self.idp_id = "saml"
 
         # a map from saml session id to Saml2SessionData object
         self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
 
         self._sso_handler = hs.get_sso_handler()
+        self._sso_handler.register_identity_provider(self)
 
-    def handle_redirect_request(
-        self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
-    ) -> bytes:
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: Optional[bytes],
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
         """Handle an incoming request to /login/sso/redirect
 
         Args:
+            request: the incoming HTTP request
             client_redirect_url: the URL that we should redirect the
-                client to when everything is done
+                client to after login (or None for UI Auth).
             ui_auth_session_id: The session ID of the ongoing UI Auth (or
                 None if this is a login).
 
         Returns:
             URL to redirect to
         """
+        if not client_redirect_url:
+            # Some SAML identity providers (e.g. Google) require a
+            # RelayState parameter on requests, so pass in a dummy redirect URL
+            # (which will never get used).
+            client_redirect_url = b"unused"
+
         reqid, info = self._saml_client.prepare_for_authenticate(
             entityid=self._saml_idp_entityid, relay_state=client_redirect_url
         )
@@ -210,7 +221,7 @@ class SamlHandler(BaseHandler):
                 return
 
             return await self._sso_handler.complete_sso_ui_auth_request(
-                self._auth_provider_id,
+                self.idp_id,
                 remote_user_id,
                 current_session.ui_auth_session_id,
                 request,
@@ -306,7 +317,7 @@ class SamlHandler(BaseHandler):
             return None
 
         await self._sso_handler.complete_sso_login_request(
-            self._auth_provider_id,
+            self.idp_id,
             remote_user_id,
             request,
             client_redirect_url,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 33cd6bc178..d8fb8cdd05 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -12,15 +12,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 import logging
 from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
 
 import attr
-from typing_extensions import NoReturn
+from typing_extensions import NoReturn, Protocol
 
 from twisted.web.http import Request
 
-from synapse.api.errors import RedirectException, SynapseError
+from synapse.api.errors import Codes, RedirectException, SynapseError
 from synapse.http.server import respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
@@ -40,6 +41,53 @@ class MappingException(Exception):
     """
 
 
+class SsoIdentityProvider(Protocol):
+    """Abstract base class to be implemented by SSO Identity Providers
+
+    An Identity Provider, or IdP, is an external HTTP service which authenticates a user
+    to say whether they should be allowed to log in, or perform a given action.
+
+    Synapse supports various implementations of IdPs, including OpenID Connect, SAML,
+    and CAS.
+
+    The main entry point is `handle_redirect_request`, which should return a URI to
+    redirect the user's browser to the IdP's authentication page.
+
+    Each IdP should be registered with the SsoHandler via
+    `hs.get_sso_handler().register_identity_provider()`, so that requests to
+    `/_matrix/client/r0/login/sso/redirect` can be correctly dispatched.
+    """
+
+    @property
+    @abc.abstractmethod
+    def idp_id(self) -> str:
+        """A unique identifier for this SSO provider
+
+        Eg, "saml", "cas", "github"
+        """
+
+    @abc.abstractmethod
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: Optional[bytes],
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
+        """Handle an incoming request to /login/sso/redirect
+
+        Args:
+            request: the incoming HTTP request
+            client_redirect_url: the URL that we should redirect the
+                client to after login (or None for UI Auth).
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
+
+        Returns:
+            URL to redirect to
+        """
+        raise NotImplementedError()
+
+
 @attr.s
 class UserAttributes:
     # the localpart of the mxid that the mapper has assigned to the user.
@@ -100,6 +148,14 @@ class SsoHandler:
         # a map from session id to session data
         self._username_mapping_sessions = {}  # type: Dict[str, UsernameMappingSession]
 
+        # map from idp_id to SsoIdentityProvider
+        self._identity_providers = {}  # type: Dict[str, SsoIdentityProvider]
+
+    def register_identity_provider(self, p: SsoIdentityProvider):
+        p_id = p.idp_id
+        assert p_id not in self._identity_providers
+        self._identity_providers[p_id] = p
+
     def render_error(
         self,
         request: Request,
@@ -124,6 +180,32 @@ class SsoHandler:
         )
         respond_with_html(request, code, html)
 
+    async def handle_redirect_request(
+        self, request: SynapseRequest, client_redirect_url: bytes,
+    ) -> str:
+        """Handle a request to /login/sso/redirect
+
+        Args:
+            request: incoming HTTP request
+            client_redirect_url: the URL that we should redirect the
+                client to after login.
+
+        Returns:
+             the URI to redirect to
+        """
+        if not self._identity_providers:
+            raise SynapseError(
+                400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
+            )
+
+        # if we only have one auth provider, redirect to it directly
+        if len(self._identity_providers) == 1:
+            ap = next(iter(self._identity_providers.values()))
+            return await ap.handle_redirect_request(request, client_redirect_url)
+
+        # otherwise, we have a configuration error
+        raise Exception("Multiple SSO identity providers have been configured!")
+
     async def get_sso_user_by_remote_user_id(
         self, auth_provider_id: str, remote_user_id: str
     ) -> Optional[str]:
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 5f4c6703db..ebc346105b 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -311,48 +311,31 @@ class LoginRestServlet(RestServlet):
         return result
 
 
-class BaseSSORedirectServlet(RestServlet):
-    """Common base class for /login/sso/redirect impls"""
-
+class SsoRedirectServlet(RestServlet):
     PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
 
+    def __init__(self, hs: "HomeServer"):
+        # make sure that the relevant handlers are instantiated, so that they
+        # register themselves with the main SSOHandler.
+        if hs.config.cas_enabled:
+            hs.get_cas_handler()
+        elif hs.config.saml2_enabled:
+            hs.get_saml_handler()
+        elif hs.config.oidc_enabled:
+            hs.get_oidc_handler()
+        self._sso_handler = hs.get_sso_handler()
+
     async def on_GET(self, request: SynapseRequest):
-        args = request.args
-        if b"redirectUrl" not in args:
-            return 400, "Redirect URL not specified for SSO auth"
-        client_redirect_url = args[b"redirectUrl"][0]
-        sso_url = await self.get_sso_url(request, client_redirect_url)
+        client_redirect_url = parse_string(
+            request, "redirectUrl", required=True, encoding=None
+        )
+        sso_url = await self._sso_handler.handle_redirect_request(
+            request, client_redirect_url
+        )
+        logger.info("Redirecting to %s", sso_url)
         request.redirect(sso_url)
         finish_request(request)
 
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        """Get the URL to redirect to, to perform SSO auth
-
-        Args:
-            request: The client request to redirect.
-            client_redirect_url: the URL that we should redirect the
-                client to when everything is done
-
-        Returns:
-            URL to redirect to
-        """
-        # to be implemented by subclasses
-        raise NotImplementedError()
-
-
-class CasRedirectServlet(BaseSSORedirectServlet):
-    def __init__(self, hs):
-        self._cas_handler = hs.get_cas_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return self._cas_handler.get_redirect_url(
-            {"redirectUrl": client_redirect_url}
-        ).encode("ascii")
-
 
 class CasTicketServlet(RestServlet):
     PATTERNS = client_patterns("/login/cas/ticket", v1=True)
@@ -379,40 +362,8 @@ class CasTicketServlet(RestServlet):
         )
 
 
-class SAMLRedirectServlet(BaseSSORedirectServlet):
-    PATTERNS = client_patterns("/login/sso/redirect", v1=True)
-
-    def __init__(self, hs):
-        self._saml_handler = hs.get_saml_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return self._saml_handler.handle_redirect_request(client_redirect_url)
-
-
-class OIDCRedirectServlet(BaseSSORedirectServlet):
-    """Implementation for /login/sso/redirect for the OIDC login flow."""
-
-    PATTERNS = client_patterns("/login/sso/redirect", v1=True)
-
-    def __init__(self, hs):
-        self._oidc_handler = hs.get_oidc_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return await self._oidc_handler.handle_redirect_request(
-            request, client_redirect_url
-        )
-
-
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
+    SsoRedirectServlet(hs).register(http_server)
     if hs.config.cas_enabled:
-        CasRedirectServlet(hs).register(http_server)
         CasTicketServlet(hs).register(http_server)
-    elif hs.config.saml2_enabled:
-        SAMLRedirectServlet(hs).register(http_server)
-    elif hs.config.oidc_enabled:
-        OIDCRedirectServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index fab077747f..9b9514632f 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -14,15 +14,20 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError
 from synapse.api.urls import CLIENT_API_PREFIX
+from synapse.handlers.sso import SsoIdentityProvider
 from synapse.http.server import respond_with_html
 from synapse.http.servlet import RestServlet, parse_string
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -35,7 +40,7 @@ class AuthRestServlet(RestServlet):
 
     PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
@@ -85,31 +90,20 @@ class AuthRestServlet(RestServlet):
         elif stagetype == LoginType.SSO:
             # Display a confirmation page which prompts the user to
             # re-authenticate with their SSO provider.
-            if self._cas_enabled:
-                # Generate a request to CAS that redirects back to an endpoint
-                # to verify the successful authentication.
-                sso_redirect_url = self._cas_handler.get_redirect_url(
-                    {"session": session},
-                )
 
+            if self._cas_enabled:
+                sso_auth_provider = self._cas_handler  # type: SsoIdentityProvider
             elif self._saml_enabled:
-                # Some SAML identity providers (e.g. Google) require a
-                # RelayState parameter on requests. It is not necessary here, so
-                # pass in a dummy redirect URL (which will never get used).
-                client_redirect_url = b"unused"
-                sso_redirect_url = self._saml_handler.handle_redirect_request(
-                    client_redirect_url, session
-                )
-
+                sso_auth_provider = self._saml_handler
             elif self._oidc_enabled:
-                client_redirect_url = b""
-                sso_redirect_url = await self._oidc_handler.handle_redirect_request(
-                    request, client_redirect_url, session
-                )
-
+                sso_auth_provider = self._oidc_handler
             else:
                 raise SynapseError(400, "Homeserver not configured for SSO.")
 
+            sso_redirect_url = await sso_auth_provider.handle_redirect_request(
+                request, None, session
+            )
+
             html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
 
         else:
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 18932d7518..999d628315 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -385,7 +385,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         channel = self.make_request("GET", cas_ticket_url)
 
         # Test that the response is HTML.
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, 200, channel.result)
         content_type_header_value = ""
         for header in channel.result.get("headers", []):
             if header[0] == b"Content-Type":