summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2021-03-04 14:44:22 +0000
committerGitHub <noreply@github.com>2021-03-04 14:44:22 +0000
commit7eb6e39a8fe9d42a411cefd905cf2caa29896923 (patch)
treeddcf4fc4eb801299d2e6191c7f34af2d3741c066 /synapse/handlers/auth.py
parentFix link in UPGRADES (diff)
downloadsynapse-7eb6e39a8fe9d42a411cefd905cf2caa29896923.tar.xz
Record the SSO Auth Provider in the login token (#9510)
This great big stack of commits is a a whole load of hoop-jumping to make it easier to store additional values in login tokens, and then to actually store the SSO Identity Provider in the login token. (Making use of that data will follow in a subsequent PR.)
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py68
1 files changed, 58 insertions, 10 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3978e41518..bec0c615d4 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -65,6 +65,7 @@ from synapse.storage.roommember import ProfileInfo
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util import stringutils as stringutils
 from synapse.util.async_helpers import maybe_awaitable
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.threepids import canonicalise_email
 
@@ -170,6 +171,16 @@ class SsoLoginExtraAttributes:
     extra_attributes = attr.ib(type=JsonDict)
 
 
+@attr.s(slots=True, frozen=True)
+class LoginTokenAttributes:
+    """Data we store in a short-term login token"""
+
+    user_id = attr.ib(type=str)
+
+    # the SSO Identity Provider that the user authenticated with, to get this token
+    auth_provider_id = attr.ib(type=str)
+
+
 class AuthHandler(BaseHandler):
     SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
 
@@ -1164,18 +1175,16 @@ class AuthHandler(BaseHandler):
             return None
         return user_id
 
-    async def validate_short_term_login_token_and_get_user_id(self, login_token: str):
-        auth_api = self.hs.get_auth()
-        user_id = None
+    async def validate_short_term_login_token(
+        self, login_token: str
+    ) -> LoginTokenAttributes:
         try:
-            macaroon = pymacaroons.Macaroon.deserialize(login_token)
-            user_id = auth_api.get_user_id_from_macaroon(macaroon)
-            auth_api.validate_macaroon(macaroon, "login", user_id)
+            res = self.macaroon_gen.verify_short_term_login_token(login_token)
         except Exception:
             raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
 
-        await self.auth.check_auth_blocking(user_id)
-        return user_id
+        await self.auth.check_auth_blocking(res.user_id)
+        return res
 
     async def delete_access_token(self, access_token: str):
         """Invalidate a single access token
@@ -1397,6 +1406,7 @@ class AuthHandler(BaseHandler):
     async def complete_sso_login(
         self,
         registered_user_id: str,
+        auth_provider_id: str,
         request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
@@ -1406,6 +1416,9 @@ class AuthHandler(BaseHandler):
 
         Args:
             registered_user_id: The registered user ID to complete SSO login for.
+            auth_provider_id: The id of the SSO Identity provider that was used for
+                login. This will be stored in the login token for future tracking in
+                prometheus metrics.
             request: The request to complete.
             client_redirect_url: The URL to which to redirect the user at the end of the
                 process.
@@ -1427,6 +1440,7 @@ class AuthHandler(BaseHandler):
 
         self._complete_sso_login(
             registered_user_id,
+            auth_provider_id,
             request,
             client_redirect_url,
             extra_attributes,
@@ -1437,6 +1451,7 @@ class AuthHandler(BaseHandler):
     def _complete_sso_login(
         self,
         registered_user_id: str,
+        auth_provider_id: str,
         request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
@@ -1463,7 +1478,7 @@ class AuthHandler(BaseHandler):
 
         # Create a login token
         login_token = self.macaroon_gen.generate_short_term_login_token(
-            registered_user_id
+            registered_user_id, auth_provider_id=auth_provider_id
         )
 
         # Append the login token to the original redirect URL (i.e. with its query
@@ -1569,15 +1584,48 @@ class MacaroonGenerator:
         return macaroon.serialize()
 
     def generate_short_term_login_token(
-        self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+        self,
+        user_id: str,
+        auth_provider_id: str,
+        duration_in_ms: int = (2 * 60 * 1000),
     ) -> str:
         macaroon = self._generate_base_macaroon(user_id)
         macaroon.add_first_party_caveat("type = login")
         now = self.hs.get_clock().time_msec()
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
+        macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
         return macaroon.serialize()
 
+    def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
+        """Verify a short-term-login macaroon
+
+        Checks that the given token is a valid, unexpired short-term-login token
+        minted by this server.
+
+        Args:
+            token: the login token to verify
+
+        Returns:
+            the user_id that this token is valid for
+
+        Raises:
+            MacaroonVerificationFailedException if the verification failed
+        """
+        macaroon = pymacaroons.Macaroon.deserialize(token)
+        user_id = get_value_from_macaroon(macaroon, "user_id")
+        auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
+
+        v = pymacaroons.Verifier()
+        v.satisfy_exact("gen = 1")
+        v.satisfy_exact("type = login")
+        v.satisfy_general(lambda c: c.startswith("user_id = "))
+        v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
+        satisfy_expiry(v, self.hs.get_clock().time_msec)
+        v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
+
+        return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
+
     def generate_delete_pusher_token(self, user_id: str) -> str:
         macaroon = self._generate_base_macaroon(user_id)
         macaroon.add_first_party_caveat("type = delete_pusher")