summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2021-03-04 14:44:22 +0000
committerGitHub <noreply@github.com>2021-03-04 14:44:22 +0000
commit7eb6e39a8fe9d42a411cefd905cf2caa29896923 (patch)
treeddcf4fc4eb801299d2e6191c7f34af2d3741c066 /synapse/handlers
parentFix link in UPGRADES (diff)
downloadsynapse-7eb6e39a8fe9d42a411cefd905cf2caa29896923.tar.xz
Record the SSO Auth Provider in the login token (#9510)
This great big stack of commits is a a whole load of hoop-jumping to make it easier to store additional values in login tokens, and then to actually store the SSO Identity Provider in the login token. (Making use of that data will follow in a subsequent PR.)
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/auth.py68
-rw-r--r--synapse/handlers/oidc_handler.py65
-rw-r--r--synapse/handlers/sso.py2
3 files changed, 74 insertions, 61 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3978e41518..bec0c615d4 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -65,6 +65,7 @@ from synapse.storage.roommember import ProfileInfo
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util import stringutils as stringutils
 from synapse.util.async_helpers import maybe_awaitable
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.threepids import canonicalise_email
 
@@ -170,6 +171,16 @@ class SsoLoginExtraAttributes:
     extra_attributes = attr.ib(type=JsonDict)
 
 
+@attr.s(slots=True, frozen=True)
+class LoginTokenAttributes:
+    """Data we store in a short-term login token"""
+
+    user_id = attr.ib(type=str)
+
+    # the SSO Identity Provider that the user authenticated with, to get this token
+    auth_provider_id = attr.ib(type=str)
+
+
 class AuthHandler(BaseHandler):
     SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
 
@@ -1164,18 +1175,16 @@ class AuthHandler(BaseHandler):
             return None
         return user_id
 
-    async def validate_short_term_login_token_and_get_user_id(self, login_token: str):
-        auth_api = self.hs.get_auth()
-        user_id = None
+    async def validate_short_term_login_token(
+        self, login_token: str
+    ) -> LoginTokenAttributes:
         try:
-            macaroon = pymacaroons.Macaroon.deserialize(login_token)
-            user_id = auth_api.get_user_id_from_macaroon(macaroon)
-            auth_api.validate_macaroon(macaroon, "login", user_id)
+            res = self.macaroon_gen.verify_short_term_login_token(login_token)
         except Exception:
             raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
 
-        await self.auth.check_auth_blocking(user_id)
-        return user_id
+        await self.auth.check_auth_blocking(res.user_id)
+        return res
 
     async def delete_access_token(self, access_token: str):
         """Invalidate a single access token
@@ -1397,6 +1406,7 @@ class AuthHandler(BaseHandler):
     async def complete_sso_login(
         self,
         registered_user_id: str,
+        auth_provider_id: str,
         request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
@@ -1406,6 +1416,9 @@ class AuthHandler(BaseHandler):
 
         Args:
             registered_user_id: The registered user ID to complete SSO login for.
+            auth_provider_id: The id of the SSO Identity provider that was used for
+                login. This will be stored in the login token for future tracking in
+                prometheus metrics.
             request: The request to complete.
             client_redirect_url: The URL to which to redirect the user at the end of the
                 process.
@@ -1427,6 +1440,7 @@ class AuthHandler(BaseHandler):
 
         self._complete_sso_login(
             registered_user_id,
+            auth_provider_id,
             request,
             client_redirect_url,
             extra_attributes,
@@ -1437,6 +1451,7 @@ class AuthHandler(BaseHandler):
     def _complete_sso_login(
         self,
         registered_user_id: str,
+        auth_provider_id: str,
         request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
@@ -1463,7 +1478,7 @@ class AuthHandler(BaseHandler):
 
         # Create a login token
         login_token = self.macaroon_gen.generate_short_term_login_token(
-            registered_user_id
+            registered_user_id, auth_provider_id=auth_provider_id
         )
 
         # Append the login token to the original redirect URL (i.e. with its query
@@ -1569,15 +1584,48 @@ class MacaroonGenerator:
         return macaroon.serialize()
 
     def generate_short_term_login_token(
-        self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+        self,
+        user_id: str,
+        auth_provider_id: str,
+        duration_in_ms: int = (2 * 60 * 1000),
     ) -> str:
         macaroon = self._generate_base_macaroon(user_id)
         macaroon.add_first_party_caveat("type = login")
         now = self.hs.get_clock().time_msec()
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
+        macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
         return macaroon.serialize()
 
+    def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
+        """Verify a short-term-login macaroon
+
+        Checks that the given token is a valid, unexpired short-term-login token
+        minted by this server.
+
+        Args:
+            token: the login token to verify
+
+        Returns:
+            the user_id that this token is valid for
+
+        Raises:
+            MacaroonVerificationFailedException if the verification failed
+        """
+        macaroon = pymacaroons.Macaroon.deserialize(token)
+        user_id = get_value_from_macaroon(macaroon, "user_id")
+        auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
+
+        v = pymacaroons.Verifier()
+        v.satisfy_exact("gen = 1")
+        v.satisfy_exact("type = login")
+        v.satisfy_general(lambda c: c.startswith("user_id = "))
+        v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
+        satisfy_expiry(v, self.hs.get_clock().time_msec)
+        v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
+
+        return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
+
     def generate_delete_pusher_token(self, user_id: str) -> str:
         macaroon = self._generate_base_macaroon(user_id)
         macaroon.add_first_party_caveat("type = delete_pusher")
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 07db1e31e4..b4a74390cc 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -42,6 +42,7 @@ from synapse.logging.context import make_deferred_yieldable
 from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
 from synapse.util import json_decoder
 from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -211,7 +212,7 @@ class OidcHandler:
             session_data = self._token_generator.verify_oidc_session_token(
                 session, state
             )
-        except (MacaroonDeserializationException, ValueError) as e:
+        except (MacaroonDeserializationException, KeyError) as e:
             logger.exception("Invalid session for OIDC callback")
             self._sso_handler.render_error(request, "invalid_session", str(e))
             return
@@ -745,7 +746,7 @@ class OidcProvider:
                 idp_id=self.idp_id,
                 nonce=nonce,
                 client_redirect_url=client_redirect_url.decode(),
-                ui_auth_session_id=ui_auth_session_id,
+                ui_auth_session_id=ui_auth_session_id or "",
             ),
         )
 
@@ -1020,10 +1021,9 @@ class OidcSessionTokenGenerator:
         macaroon.add_first_party_caveat(
             "client_redirect_url = %s" % (session_data.client_redirect_url,)
         )
-        if session_data.ui_auth_session_id:
-            macaroon.add_first_party_caveat(
-                "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
-            )
+        macaroon.add_first_party_caveat(
+            "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
+        )
         now = self._clock.time_msec()
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
@@ -1046,7 +1046,7 @@ class OidcSessionTokenGenerator:
             The data extracted from the session cookie
 
         Raises:
-            ValueError if an expected caveat is missing from the macaroon.
+            KeyError if an expected caveat is missing from the macaroon.
         """
         macaroon = pymacaroons.Macaroon.deserialize(session)
 
@@ -1057,26 +1057,16 @@ class OidcSessionTokenGenerator:
         v.satisfy_general(lambda c: c.startswith("nonce = "))
         v.satisfy_general(lambda c: c.startswith("idp_id = "))
         v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
-        # Sometimes there's a UI auth session ID, it seems to be OK to attempt
-        # to always satisfy this.
         v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
-        v.satisfy_general(self._verify_expiry)
+        satisfy_expiry(v, self._clock.time_msec)
 
         v.verify(macaroon, self._macaroon_secret_key)
 
         # Extract the session data from the token.
-        nonce = self._get_value_from_macaroon(macaroon, "nonce")
-        idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
-        client_redirect_url = self._get_value_from_macaroon(
-            macaroon, "client_redirect_url"
-        )
-        try:
-            ui_auth_session_id = self._get_value_from_macaroon(
-                macaroon, "ui_auth_session_id"
-            )  # type: Optional[str]
-        except ValueError:
-            ui_auth_session_id = None
-
+        nonce = get_value_from_macaroon(macaroon, "nonce")
+        idp_id = get_value_from_macaroon(macaroon, "idp_id")
+        client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url")
+        ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id")
         return OidcSessionData(
             nonce=nonce,
             idp_id=idp_id,
@@ -1084,33 +1074,6 @@ class OidcSessionTokenGenerator:
             ui_auth_session_id=ui_auth_session_id,
         )
 
-    def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
-        """Extracts a caveat value from a macaroon token.
-
-        Args:
-            macaroon: the token
-            key: the key of the caveat to extract
-
-        Returns:
-            The extracted value
-
-        Raises:
-            ValueError: if the caveat was not in the macaroon
-        """
-        prefix = key + " = "
-        for caveat in macaroon.caveats:
-            if caveat.caveat_id.startswith(prefix):
-                return caveat.caveat_id[len(prefix) :]
-        raise ValueError("No %s caveat in macaroon" % (key,))
-
-    def _verify_expiry(self, caveat: str) -> bool:
-        prefix = "time < "
-        if not caveat.startswith(prefix):
-            return False
-        expiry = int(caveat[len(prefix) :])
-        now = self._clock.time_msec()
-        return now < expiry
-
 
 @attr.s(frozen=True, slots=True)
 class OidcSessionData:
@@ -1125,8 +1088,8 @@ class OidcSessionData:
     # The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
     client_redirect_url = attr.ib(type=str)
 
-    # The session ID of the ongoing UI Auth (None if this is a login)
-    ui_auth_session_id = attr.ib(type=Optional[str], default=None)
+    # The session ID of the ongoing UI Auth ("" if this is a login)
+    ui_auth_session_id = attr.ib(type=str)
 
 
 UserAttributeDict = TypedDict(
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 80e28bdcbe..8a22dab54a 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -456,6 +456,7 @@ class SsoHandler:
 
         await self._auth_handler.complete_sso_login(
             user_id,
+            auth_provider_id,
             request,
             client_redirect_url,
             extra_login_attributes,
@@ -886,6 +887,7 @@ class SsoHandler:
 
         await self._auth_handler.complete_sso_login(
             user_id,
+            session.auth_provider_id,
             request,
             session.client_redirect_url,
             session.extra_login_attributes,