summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/auth.py4
-rw-r--r--synapse/handlers/oidc_handler.py76
-rw-r--r--synapse/rest/client/v1/login.py31
-rw-r--r--synapse/rest/client/v2_alpha/auth.py19
4 files changed, 94 insertions, 36 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 524281d2f1..75b39e878c 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -80,7 +80,9 @@ class AuthHandler(BaseHandler):
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.macaroon_gen = hs.get_macaroon_generator()
         self._password_enabled = hs.config.password_enabled
-        self._sso_enabled = hs.config.saml2_enabled or hs.config.cas_enabled
+        self._sso_enabled = (
+            hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
+        )
 
         # we keep this as a list despite the O(N^2) implication so that we can
         # keep PASSWORD first and avoid confusing clients which pick the first
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 178f263439..4ba8c7fda5 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -311,7 +311,7 @@ class OidcHandler:
         ``ClientAuth`` to authenticate with the client with its ID and secret.
 
         Args:
-            code: The autorization code we got from the callback.
+            code: The authorization code we got from the callback.
 
         Returns:
             A dict containing various tokens.
@@ -497,11 +497,14 @@ class OidcHandler:
         return UserInfo(claims)
 
     async def handle_redirect_request(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> None:
+        self,
+        request: SynapseRequest,
+        client_redirect_url: bytes,
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
         """Handle an incoming request to /login/sso/redirect
 
-        It redirects the browser to the authorization endpoint with a few
+        It returns a redirect to the authorization endpoint with a few
         parameters:
 
           - ``client_id``: the client ID set in ``oidc_config.client_id``
@@ -511,24 +514,32 @@ class OidcHandler:
           - ``state``: a random string
           - ``nonce``: a random string
 
-        In addition to redirecting the client, we are setting a cookie with
+        In addition generating a redirect URL, we are setting a cookie with
         a signed macaroon token containing the state, the nonce and the
         client_redirect_url params. Those are then checked when the client
         comes back from the provider.
 
-
         Args:
             request: the incoming request from the browser.
                 We'll respond to it with a redirect and a cookie.
             client_redirect_url: the URL that we should redirect the client to
                 when everything is done
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
+
+        Returns:
+            The redirect URL to the authorization endpoint.
+
         """
 
         state = generate_token()
         nonce = generate_token()
 
         cookie = self._generate_oidc_session_token(
-            state=state, nonce=nonce, client_redirect_url=client_redirect_url.decode(),
+            state=state,
+            nonce=nonce,
+            client_redirect_url=client_redirect_url.decode(),
+            ui_auth_session_id=ui_auth_session_id,
         )
         request.addCookie(
             SESSION_COOKIE_NAME,
@@ -541,7 +552,7 @@ class OidcHandler:
 
         metadata = await self.load_metadata()
         authorization_endpoint = metadata.get("authorization_endpoint")
-        uri = prepare_grant_uri(
+        return prepare_grant_uri(
             authorization_endpoint,
             client_id=self._client_auth.client_id,
             response_type="code",
@@ -550,8 +561,6 @@ class OidcHandler:
             state=state,
             nonce=nonce,
         )
-        request.redirect(uri)
-        finish_request(request)
 
     async def handle_oidc_callback(self, request: SynapseRequest) -> None:
         """Handle an incoming request to /_synapse/oidc/callback
@@ -625,7 +634,11 @@ class OidcHandler:
 
         # Deserialize the session token and verify it.
         try:
-            nonce, client_redirect_url = self._verify_oidc_session_token(session, state)
+            (
+                nonce,
+                client_redirect_url,
+                ui_auth_session_id,
+            ) = self._verify_oidc_session_token(session, state)
         except MacaroonDeserializationException as e:
             logger.exception("Invalid session")
             self._render_error(request, "invalid_session", str(e))
@@ -678,15 +691,21 @@ class OidcHandler:
             return
 
         # and finally complete the login
-        await self._auth_handler.complete_sso_login(
-            user_id, request, client_redirect_url
-        )
+        if ui_auth_session_id:
+            await self._auth_handler.complete_sso_ui_auth(
+                user_id, ui_auth_session_id, request
+            )
+        else:
+            await self._auth_handler.complete_sso_login(
+                user_id, request, client_redirect_url
+            )
 
     def _generate_oidc_session_token(
         self,
         state: str,
         nonce: str,
         client_redirect_url: str,
+        ui_auth_session_id: Optional[str],
         duration_in_ms: int = (60 * 60 * 1000),
     ) -> str:
         """Generates a signed token storing data about an OIDC session.
@@ -702,6 +721,8 @@ class OidcHandler:
             nonce: The ``nonce`` parameter passed to the OIDC provider.
             client_redirect_url: The URL the client gave when it initiated the
                 flow.
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
             duration_in_ms: An optional duration for the token in milliseconds.
                 Defaults to an hour.
 
@@ -718,12 +739,19 @@ class OidcHandler:
         macaroon.add_first_party_caveat(
             "client_redirect_url = %s" % (client_redirect_url,)
         )
+        if ui_auth_session_id:
+            macaroon.add_first_party_caveat(
+                "ui_auth_session_id = %s" % (ui_auth_session_id,)
+            )
         now = self._clock.time_msec()
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
+
         return macaroon.serialize()
 
-    def _verify_oidc_session_token(self, session: str, state: str) -> Tuple[str, str]:
+    def _verify_oidc_session_token(
+        self, session: str, state: str
+    ) -> Tuple[str, str, Optional[str]]:
         """Verifies and extract an OIDC session token.
 
         This verifies that a given session token was issued by this homeserver
@@ -734,7 +762,7 @@ class OidcHandler:
             state: The state the OIDC provider gave back
 
         Returns:
-            The nonce and the client_redirect_url for this session
+            The nonce, client_redirect_url, and ui_auth_session_id for this session
         """
         macaroon = pymacaroons.Macaroon.deserialize(session)
 
@@ -744,17 +772,27 @@ class OidcHandler:
         v.satisfy_exact("state = %s" % (state,))
         v.satisfy_general(lambda c: c.startswith("nonce = "))
         v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
+        # Sometimes there's a UI auth session ID, it seems to be OK to attempt
+        # to always satisfy this.
+        v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
         v.satisfy_general(self._verify_expiry)
 
         v.verify(macaroon, self._macaroon_secret_key)
 
-        # Extract the `nonce` and `client_redirect_url` from the token
+        # Extract the `nonce`, `client_redirect_url`, and maybe the
+        # `ui_auth_session_id` from the token.
         nonce = self._get_value_from_macaroon(macaroon, "nonce")
         client_redirect_url = self._get_value_from_macaroon(
             macaroon, "client_redirect_url"
         )
+        try:
+            ui_auth_session_id = self._get_value_from_macaroon(
+                macaroon, "ui_auth_session_id"
+            )  # type: Optional[str]
+        except ValueError:
+            ui_auth_session_id = None
 
-        return nonce, client_redirect_url
+        return nonce, client_redirect_url, ui_auth_session_id
 
     def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
         """Extracts a caveat value from a macaroon token.
@@ -773,7 +811,7 @@ class OidcHandler:
         for caveat in macaroon.caveats:
             if caveat.caveat_id.startswith(prefix):
                 return caveat.caveat_id[len(prefix) :]
-        raise Exception("No %s caveat in macaroon" % (key,))
+        raise ValueError("No %s caveat in macaroon" % (key,))
 
     def _verify_expiry(self, caveat: str) -> bool:
         prefix = "time < "
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index de7eca21f8..d89b2e5532 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -401,19 +401,22 @@ class BaseSSORedirectServlet(RestServlet):
 
     PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
 
-    def on_GET(self, request: SynapseRequest):
+    async def on_GET(self, request: SynapseRequest):
         args = request.args
         if b"redirectUrl" not in args:
             return 400, "Redirect URL not specified for SSO auth"
         client_redirect_url = args[b"redirectUrl"][0]
-        sso_url = self.get_sso_url(client_redirect_url)
+        sso_url = await self.get_sso_url(request, client_redirect_url)
         request.redirect(sso_url)
         finish_request(request)
 
-    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
         """Get the URL to redirect to, to perform SSO auth
 
         Args:
+            request: The client request to redirect.
             client_redirect_url: the URL that we should redirect the
                 client to when everything is done
 
@@ -428,7 +431,9 @@ class CasRedirectServlet(BaseSSORedirectServlet):
     def __init__(self, hs):
         self._cas_handler = hs.get_cas_handler()
 
-    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
         return self._cas_handler.get_redirect_url(
             {"redirectUrl": client_redirect_url}
         ).encode("ascii")
@@ -465,11 +470,13 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
     def __init__(self, hs):
         self._saml_handler = hs.get_saml_handler()
 
-    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
         return self._saml_handler.handle_redirect_request(client_redirect_url)
 
 
-class OIDCRedirectServlet(RestServlet):
+class OIDCRedirectServlet(BaseSSORedirectServlet):
     """Implementation for /login/sso/redirect for the OIDC login flow."""
 
     PATTERNS = client_patterns("/login/sso/redirect", v1=True)
@@ -477,12 +484,12 @@ class OIDCRedirectServlet(RestServlet):
     def __init__(self, hs):
         self._oidc_handler = hs.get_oidc_handler()
 
-    async def on_GET(self, request):
-        args = request.args
-        if b"redirectUrl" not in args:
-            return 400, "Redirect URL not specified for SSO auth"
-        client_redirect_url = args[b"redirectUrl"][0]
-        await self._oidc_handler.handle_redirect_request(request, client_redirect_url)
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
+        return await self._oidc_handler.handle_redirect_request(
+            request, client_redirect_url
+        )
 
 
 def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 24dd3d3e96..7bca1326d5 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -131,14 +131,19 @@ class AuthRestServlet(RestServlet):
         self.registration_handler = hs.get_registration_handler()
 
         # SSO configuration.
-        self._saml_enabled = hs.config.saml2_enabled
-        if self._saml_enabled:
-            self._saml_handler = hs.get_saml_handler()
         self._cas_enabled = hs.config.cas_enabled
         if self._cas_enabled:
             self._cas_handler = hs.get_cas_handler()
             self._cas_server_url = hs.config.cas_server_url
             self._cas_service_url = hs.config.cas_service_url
+        self._saml_enabled = hs.config.saml2_enabled
+        if self._saml_enabled:
+            self._saml_handler = hs.get_saml_handler()
+        self._oidc_enabled = hs.config.oidc_enabled
+        if self._oidc_enabled:
+            self._oidc_handler = hs.get_oidc_handler()
+            self._cas_server_url = hs.config.cas_server_url
+            self._cas_service_url = hs.config.cas_service_url
 
     async def on_GET(self, request, stagetype):
         session = parse_string(request, "session")
@@ -172,11 +177,17 @@ class AuthRestServlet(RestServlet):
                 )
 
             elif self._saml_enabled:
-                client_redirect_url = ""
+                client_redirect_url = b""
                 sso_redirect_url = self._saml_handler.handle_redirect_request(
                     client_redirect_url, session
                 )
 
+            elif self._oidc_enabled:
+                client_redirect_url = b""
+                sso_redirect_url = await self._oidc_handler.handle_redirect_request(
+                    request, client_redirect_url, session
+                )
+
             else:
                 raise SynapseError(400, "Homeserver not configured for SSO.")