diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 178f263439..4ba8c7fda5 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -311,7 +311,7 @@ class OidcHandler:
``ClientAuth`` to authenticate with the client with its ID and secret.
Args:
- code: The autorization code we got from the callback.
+ code: The authorization code we got from the callback.
Returns:
A dict containing various tokens.
@@ -497,11 +497,14 @@ class OidcHandler:
return UserInfo(claims)
async def handle_redirect_request(
- self, request: SynapseRequest, client_redirect_url: bytes
- ) -> None:
+ self,
+ request: SynapseRequest,
+ client_redirect_url: bytes,
+ ui_auth_session_id: Optional[str] = None,
+ ) -> str:
"""Handle an incoming request to /login/sso/redirect
- It redirects the browser to the authorization endpoint with a few
+ It returns a redirect to the authorization endpoint with a few
parameters:
- ``client_id``: the client ID set in ``oidc_config.client_id``
@@ -511,24 +514,32 @@ class OidcHandler:
- ``state``: a random string
- ``nonce``: a random string
- In addition to redirecting the client, we are setting a cookie with
+ In addition generating a redirect URL, we are setting a cookie with
a signed macaroon token containing the state, the nonce and the
client_redirect_url params. Those are then checked when the client
comes back from the provider.
-
Args:
request: the incoming request from the browser.
We'll respond to it with a redirect and a cookie.
client_redirect_url: the URL that we should redirect the client to
when everything is done
+ ui_auth_session_id: The session ID of the ongoing UI Auth (or
+ None if this is a login).
+
+ Returns:
+ The redirect URL to the authorization endpoint.
+
"""
state = generate_token()
nonce = generate_token()
cookie = self._generate_oidc_session_token(
- state=state, nonce=nonce, client_redirect_url=client_redirect_url.decode(),
+ state=state,
+ nonce=nonce,
+ client_redirect_url=client_redirect_url.decode(),
+ ui_auth_session_id=ui_auth_session_id,
)
request.addCookie(
SESSION_COOKIE_NAME,
@@ -541,7 +552,7 @@ class OidcHandler:
metadata = await self.load_metadata()
authorization_endpoint = metadata.get("authorization_endpoint")
- uri = prepare_grant_uri(
+ return prepare_grant_uri(
authorization_endpoint,
client_id=self._client_auth.client_id,
response_type="code",
@@ -550,8 +561,6 @@ class OidcHandler:
state=state,
nonce=nonce,
)
- request.redirect(uri)
- finish_request(request)
async def handle_oidc_callback(self, request: SynapseRequest) -> None:
"""Handle an incoming request to /_synapse/oidc/callback
@@ -625,7 +634,11 @@ class OidcHandler:
# Deserialize the session token and verify it.
try:
- nonce, client_redirect_url = self._verify_oidc_session_token(session, state)
+ (
+ nonce,
+ client_redirect_url,
+ ui_auth_session_id,
+ ) = self._verify_oidc_session_token(session, state)
except MacaroonDeserializationException as e:
logger.exception("Invalid session")
self._render_error(request, "invalid_session", str(e))
@@ -678,15 +691,21 @@ class OidcHandler:
return
# and finally complete the login
- await self._auth_handler.complete_sso_login(
- user_id, request, client_redirect_url
- )
+ if ui_auth_session_id:
+ await self._auth_handler.complete_sso_ui_auth(
+ user_id, ui_auth_session_id, request
+ )
+ else:
+ await self._auth_handler.complete_sso_login(
+ user_id, request, client_redirect_url
+ )
def _generate_oidc_session_token(
self,
state: str,
nonce: str,
client_redirect_url: str,
+ ui_auth_session_id: Optional[str],
duration_in_ms: int = (60 * 60 * 1000),
) -> str:
"""Generates a signed token storing data about an OIDC session.
@@ -702,6 +721,8 @@ class OidcHandler:
nonce: The ``nonce`` parameter passed to the OIDC provider.
client_redirect_url: The URL the client gave when it initiated the
flow.
+ ui_auth_session_id: The session ID of the ongoing UI Auth (or
+ None if this is a login).
duration_in_ms: An optional duration for the token in milliseconds.
Defaults to an hour.
@@ -718,12 +739,19 @@ class OidcHandler:
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (client_redirect_url,)
)
+ if ui_auth_session_id:
+ macaroon.add_first_party_caveat(
+ "ui_auth_session_id = %s" % (ui_auth_session_id,)
+ )
now = self._clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
+
return macaroon.serialize()
- def _verify_oidc_session_token(self, session: str, state: str) -> Tuple[str, str]:
+ def _verify_oidc_session_token(
+ self, session: str, state: str
+ ) -> Tuple[str, str, Optional[str]]:
"""Verifies and extract an OIDC session token.
This verifies that a given session token was issued by this homeserver
@@ -734,7 +762,7 @@ class OidcHandler:
state: The state the OIDC provider gave back
Returns:
- The nonce and the client_redirect_url for this session
+ The nonce, client_redirect_url, and ui_auth_session_id for this session
"""
macaroon = pymacaroons.Macaroon.deserialize(session)
@@ -744,17 +772,27 @@ class OidcHandler:
v.satisfy_exact("state = %s" % (state,))
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
+ # Sometimes there's a UI auth session ID, it seems to be OK to attempt
+ # to always satisfy this.
+ v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
v.satisfy_general(self._verify_expiry)
v.verify(macaroon, self._macaroon_secret_key)
- # Extract the `nonce` and `client_redirect_url` from the token
+ # Extract the `nonce`, `client_redirect_url`, and maybe the
+ # `ui_auth_session_id` from the token.
nonce = self._get_value_from_macaroon(macaroon, "nonce")
client_redirect_url = self._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
+ try:
+ ui_auth_session_id = self._get_value_from_macaroon(
+ macaroon, "ui_auth_session_id"
+ ) # type: Optional[str]
+ except ValueError:
+ ui_auth_session_id = None
- return nonce, client_redirect_url
+ return nonce, client_redirect_url, ui_auth_session_id
def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
"""Extracts a caveat value from a macaroon token.
@@ -773,7 +811,7 @@ class OidcHandler:
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(prefix):
return caveat.caveat_id[len(prefix) :]
- raise Exception("No %s caveat in macaroon" % (key,))
+ raise ValueError("No %s caveat in macaroon" % (key,))
def _verify_expiry(self, caveat: str) -> bool:
prefix = "time < "
|