summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py34
1 files changed, 31 insertions, 3 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 4d9c4e5834..61607cf2ba 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -39,6 +39,7 @@ import attr
 import bcrypt
 import pymacaroons
 import unpaddedbase64
+from pymacaroons.exceptions import MacaroonVerificationFailedException
 
 from twisted.web.server import Request
 
@@ -182,8 +183,11 @@ class LoginTokenAttributes:
 
     user_id = attr.ib(type=str)
 
-    # the SSO Identity Provider that the user authenticated with, to get this token
     auth_provider_id = attr.ib(type=str)
+    """The SSO Identity Provider that the user authenticated with, to get this token."""
+
+    auth_provider_session_id = attr.ib(type=Optional[str])
+    """The session ID advertised by the SSO Identity Provider."""
 
 
 class AuthHandler:
@@ -1650,6 +1654,7 @@ class AuthHandler:
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
         new_user: bool = False,
+        auth_provider_session_id: Optional[str] = None,
     ) -> None:
         """Having figured out a mxid for this user, complete the HTTP request
 
@@ -1665,6 +1670,7 @@ class AuthHandler:
                 during successful login. Must be JSON serializable.
             new_user: True if we should use wording appropriate to a user who has just
                 registered.
+            auth_provider_session_id: The session ID from the SSO IdP received during login.
         """
         # If the account has been deactivated, do not proceed with the login
         # flow.
@@ -1685,6 +1691,7 @@ class AuthHandler:
             extra_attributes,
             new_user=new_user,
             user_profile_data=profile,
+            auth_provider_session_id=auth_provider_session_id,
         )
 
     def _complete_sso_login(
@@ -1696,6 +1703,7 @@ class AuthHandler:
         extra_attributes: Optional[JsonDict] = None,
         new_user: bool = False,
         user_profile_data: Optional[ProfileInfo] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> None:
         """
         The synchronous portion of complete_sso_login.
@@ -1717,7 +1725,9 @@ class AuthHandler:
 
         # Create a login token
         login_token = self.macaroon_gen.generate_short_term_login_token(
-            registered_user_id, auth_provider_id=auth_provider_id
+            registered_user_id,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
         )
 
         # Append the login token to the original redirect URL (i.e. with its query
@@ -1822,6 +1832,7 @@ class MacaroonGenerator:
         self,
         user_id: str,
         auth_provider_id: str,
+        auth_provider_session_id: Optional[str] = None,
         duration_in_ms: int = (2 * 60 * 1000),
     ) -> str:
         macaroon = self._generate_base_macaroon(user_id)
@@ -1830,6 +1841,10 @@ class MacaroonGenerator:
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
         macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
+        if auth_provider_session_id is not None:
+            macaroon.add_first_party_caveat(
+                "auth_provider_session_id = %s" % (auth_provider_session_id,)
+            )
         return macaroon.serialize()
 
     def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
@@ -1851,15 +1866,28 @@ class MacaroonGenerator:
         user_id = get_value_from_macaroon(macaroon, "user_id")
         auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
 
+        auth_provider_session_id: Optional[str] = None
+        try:
+            auth_provider_session_id = get_value_from_macaroon(
+                macaroon, "auth_provider_session_id"
+            )
+        except MacaroonVerificationFailedException:
+            pass
+
         v = pymacaroons.Verifier()
         v.satisfy_exact("gen = 1")
         v.satisfy_exact("type = login")
         v.satisfy_general(lambda c: c.startswith("user_id = "))
         v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
+        v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
         satisfy_expiry(v, self.hs.get_clock().time_msec)
         v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
 
-        return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
+        return LoginTokenAttributes(
+            user_id=user_id,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
+        )
 
     def generate_delete_pusher_token(self, user_id: str) -> str:
         macaroon = self._generate_base_macaroon(user_id)