summary refs log tree commit diff
path: root/synapse/handlers/oidc_handler.py
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2021-03-04 14:44:22 +0000
committerGitHub <noreply@github.com>2021-03-04 14:44:22 +0000
commit7eb6e39a8fe9d42a411cefd905cf2caa29896923 (patch)
treeddcf4fc4eb801299d2e6191c7f34af2d3741c066 /synapse/handlers/oidc_handler.py
parentFix link in UPGRADES (diff)
downloadsynapse-7eb6e39a8fe9d42a411cefd905cf2caa29896923.tar.xz
Record the SSO Auth Provider in the login token (#9510)
This great big stack of commits is a a whole load of hoop-jumping to make it easier to store additional values in login tokens, and then to actually store the SSO Identity Provider in the login token. (Making use of that data will follow in a subsequent PR.)
Diffstat (limited to 'synapse/handlers/oidc_handler.py')
-rw-r--r--synapse/handlers/oidc_handler.py65
1 files changed, 14 insertions, 51 deletions
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 07db1e31e4..b4a74390cc 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -42,6 +42,7 @@ from synapse.logging.context import make_deferred_yieldable
 from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
 from synapse.util import json_decoder
 from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -211,7 +212,7 @@ class OidcHandler:
             session_data = self._token_generator.verify_oidc_session_token(
                 session, state
             )
-        except (MacaroonDeserializationException, ValueError) as e:
+        except (MacaroonDeserializationException, KeyError) as e:
             logger.exception("Invalid session for OIDC callback")
             self._sso_handler.render_error(request, "invalid_session", str(e))
             return
@@ -745,7 +746,7 @@ class OidcProvider:
                 idp_id=self.idp_id,
                 nonce=nonce,
                 client_redirect_url=client_redirect_url.decode(),
-                ui_auth_session_id=ui_auth_session_id,
+                ui_auth_session_id=ui_auth_session_id or "",
             ),
         )
 
@@ -1020,10 +1021,9 @@ class OidcSessionTokenGenerator:
         macaroon.add_first_party_caveat(
             "client_redirect_url = %s" % (session_data.client_redirect_url,)
         )
-        if session_data.ui_auth_session_id:
-            macaroon.add_first_party_caveat(
-                "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
-            )
+        macaroon.add_first_party_caveat(
+            "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
+        )
         now = self._clock.time_msec()
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
@@ -1046,7 +1046,7 @@ class OidcSessionTokenGenerator:
             The data extracted from the session cookie
 
         Raises:
-            ValueError if an expected caveat is missing from the macaroon.
+            KeyError if an expected caveat is missing from the macaroon.
         """
         macaroon = pymacaroons.Macaroon.deserialize(session)
 
@@ -1057,26 +1057,16 @@ class OidcSessionTokenGenerator:
         v.satisfy_general(lambda c: c.startswith("nonce = "))
         v.satisfy_general(lambda c: c.startswith("idp_id = "))
         v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
-        # Sometimes there's a UI auth session ID, it seems to be OK to attempt
-        # to always satisfy this.
         v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
-        v.satisfy_general(self._verify_expiry)
+        satisfy_expiry(v, self._clock.time_msec)
 
         v.verify(macaroon, self._macaroon_secret_key)
 
         # Extract the session data from the token.
-        nonce = self._get_value_from_macaroon(macaroon, "nonce")
-        idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
-        client_redirect_url = self._get_value_from_macaroon(
-            macaroon, "client_redirect_url"
-        )
-        try:
-            ui_auth_session_id = self._get_value_from_macaroon(
-                macaroon, "ui_auth_session_id"
-            )  # type: Optional[str]
-        except ValueError:
-            ui_auth_session_id = None
-
+        nonce = get_value_from_macaroon(macaroon, "nonce")
+        idp_id = get_value_from_macaroon(macaroon, "idp_id")
+        client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url")
+        ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id")
         return OidcSessionData(
             nonce=nonce,
             idp_id=idp_id,
@@ -1084,33 +1074,6 @@ class OidcSessionTokenGenerator:
             ui_auth_session_id=ui_auth_session_id,
         )
 
-    def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
-        """Extracts a caveat value from a macaroon token.
-
-        Args:
-            macaroon: the token
-            key: the key of the caveat to extract
-
-        Returns:
-            The extracted value
-
-        Raises:
-            ValueError: if the caveat was not in the macaroon
-        """
-        prefix = key + " = "
-        for caveat in macaroon.caveats:
-            if caveat.caveat_id.startswith(prefix):
-                return caveat.caveat_id[len(prefix) :]
-        raise ValueError("No %s caveat in macaroon" % (key,))
-
-    def _verify_expiry(self, caveat: str) -> bool:
-        prefix = "time < "
-        if not caveat.startswith(prefix):
-            return False
-        expiry = int(caveat[len(prefix) :])
-        now = self._clock.time_msec()
-        return now < expiry
-
 
 @attr.s(frozen=True, slots=True)
 class OidcSessionData:
@@ -1125,8 +1088,8 @@ class OidcSessionData:
     # The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
     client_redirect_url = attr.ib(type=str)
 
-    # The session ID of the ongoing UI Auth (None if this is a login)
-    ui_auth_session_id = attr.ib(type=Optional[str], default=None)
+    # The session ID of the ongoing UI Auth ("" if this is a login)
+    ui_auth_session_id = attr.ib(type=str)
 
 
 UserAttributeDict = TypedDict(