summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py158
1 files changed, 127 insertions, 31 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 60e59d11a0..61607cf2ba 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -18,6 +18,7 @@ import time
 import unicodedata
 import urllib.parse
 from binascii import crc32
+from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -38,6 +39,7 @@ import attr
 import bcrypt
 import pymacaroons
 import unpaddedbase64
+from pymacaroons.exceptions import MacaroonVerificationFailedException
 
 from twisted.web.server import Request
 
@@ -181,8 +183,11 @@ class LoginTokenAttributes:
 
     user_id = attr.ib(type=str)
 
-    # the SSO Identity Provider that the user authenticated with, to get this token
     auth_provider_id = attr.ib(type=str)
+    """The SSO Identity Provider that the user authenticated with, to get this token."""
+
+    auth_provider_session_id = attr.ib(type=Optional[str])
+    """The session ID advertised by the SSO Identity Provider."""
 
 
 class AuthHandler:
@@ -756,53 +761,109 @@ class AuthHandler:
     async def refresh_token(
         self,
         refresh_token: str,
-        valid_until_ms: Optional[int],
-    ) -> Tuple[str, str]:
+        access_token_valid_until_ms: Optional[int],
+        refresh_token_valid_until_ms: Optional[int],
+    ) -> Tuple[str, str, Optional[int]]:
         """
         Consumes a refresh token and generate both a new access token and a new refresh token from it.
 
         The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token.
 
+        The lifetime of both the access token and refresh token will be capped so that they
+        do not exceed the session's ultimate expiry time, if applicable.
+
         Args:
             refresh_token: The token to consume.
-            valid_until_ms: The expiration timestamp of the new access token.
-
+            access_token_valid_until_ms: The expiration timestamp of the new access token.
+                None if the access token does not expire.
+            refresh_token_valid_until_ms: The expiration timestamp of the new refresh token.
+                None if the refresh token does not expire.
         Returns:
-            A tuple containing the new access token and refresh token
+            A tuple containing:
+              - the new access token
+              - the new refresh token
+              - the actual expiry time of the access token, which may be earlier than
+                `access_token_valid_until_ms`.
         """
 
         # Verify the token signature first before looking up the token
         if not self._verify_refresh_token(refresh_token):
-            raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN)
+            raise SynapseError(
+                HTTPStatus.UNAUTHORIZED, "invalid refresh token", Codes.UNKNOWN_TOKEN
+            )
 
         existing_token = await self.store.lookup_refresh_token(refresh_token)
         if existing_token is None:
-            raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN)
+            raise SynapseError(
+                HTTPStatus.UNAUTHORIZED,
+                "refresh token does not exist",
+                Codes.UNKNOWN_TOKEN,
+            )
 
         if (
             existing_token.has_next_access_token_been_used
             or existing_token.has_next_refresh_token_been_refreshed
         ):
             raise SynapseError(
-                403, "refresh token isn't valid anymore", Codes.FORBIDDEN
+                HTTPStatus.FORBIDDEN,
+                "refresh token isn't valid anymore",
+                Codes.FORBIDDEN,
             )
 
+        now_ms = self._clock.time_msec()
+
+        if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms:
+
+            raise SynapseError(
+                HTTPStatus.FORBIDDEN,
+                "The supplied refresh token has expired",
+                Codes.FORBIDDEN,
+            )
+
+        if existing_token.ultimate_session_expiry_ts is not None:
+            # This session has a bounded lifetime, even across refreshes.
+
+            if access_token_valid_until_ms is not None:
+                access_token_valid_until_ms = min(
+                    access_token_valid_until_ms,
+                    existing_token.ultimate_session_expiry_ts,
+                )
+            else:
+                access_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
+
+            if refresh_token_valid_until_ms is not None:
+                refresh_token_valid_until_ms = min(
+                    refresh_token_valid_until_ms,
+                    existing_token.ultimate_session_expiry_ts,
+                )
+            else:
+                refresh_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
+            if existing_token.ultimate_session_expiry_ts < now_ms:
+                raise SynapseError(
+                    HTTPStatus.FORBIDDEN,
+                    "The session has expired and can no longer be refreshed",
+                    Codes.FORBIDDEN,
+                )
+
         (
             new_refresh_token,
             new_refresh_token_id,
-        ) = await self.get_refresh_token_for_user_id(
-            user_id=existing_token.user_id, device_id=existing_token.device_id
+        ) = await self.create_refresh_token_for_user_id(
+            user_id=existing_token.user_id,
+            device_id=existing_token.device_id,
+            expiry_ts=refresh_token_valid_until_ms,
+            ultimate_session_expiry_ts=existing_token.ultimate_session_expiry_ts,
         )
-        access_token = await self.get_access_token_for_user_id(
+        access_token = await self.create_access_token_for_user_id(
             user_id=existing_token.user_id,
             device_id=existing_token.device_id,
-            valid_until_ms=valid_until_ms,
+            valid_until_ms=access_token_valid_until_ms,
             refresh_token_id=new_refresh_token_id,
         )
         await self.store.replace_refresh_token(
             existing_token.token_id, new_refresh_token_id
         )
-        return access_token, new_refresh_token
+        return access_token, new_refresh_token, access_token_valid_until_ms
 
     def _verify_refresh_token(self, token: str) -> bool:
         """
@@ -832,10 +893,12 @@ class AuthHandler:
 
         return True
 
-    async def get_refresh_token_for_user_id(
+    async def create_refresh_token_for_user_id(
         self,
         user_id: str,
         device_id: str,
+        expiry_ts: Optional[int],
+        ultimate_session_expiry_ts: Optional[int],
     ) -> Tuple[str, int]:
         """
         Creates a new refresh token for the user with the given user ID.
@@ -843,6 +906,13 @@ class AuthHandler:
         Args:
             user_id: canonical user ID
             device_id: the device ID to associate with the token.
+            expiry_ts (milliseconds since the epoch): Time after which the
+                refresh token cannot be used.
+                If None, the refresh token never expires until it has been used.
+            ultimate_session_expiry_ts (milliseconds since the epoch):
+                Time at which the session will end and can not be extended any
+                further.
+                If None, the session can be refreshed indefinitely.
 
         Returns:
             The newly created refresh token and its ID in the database
@@ -852,10 +922,12 @@ class AuthHandler:
             user_id=user_id,
             token=refresh_token,
             device_id=device_id,
+            expiry_ts=expiry_ts,
+            ultimate_session_expiry_ts=ultimate_session_expiry_ts,
         )
         return refresh_token, refresh_token_id
 
-    async def get_access_token_for_user_id(
+    async def create_access_token_for_user_id(
         self,
         user_id: str,
         device_id: Optional[str],
@@ -1582,6 +1654,7 @@ class AuthHandler:
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
         new_user: bool = False,
+        auth_provider_session_id: Optional[str] = None,
     ) -> None:
         """Having figured out a mxid for this user, complete the HTTP request
 
@@ -1597,6 +1670,7 @@ class AuthHandler:
                 during successful login. Must be JSON serializable.
             new_user: True if we should use wording appropriate to a user who has just
                 registered.
+            auth_provider_session_id: The session ID from the SSO IdP received during login.
         """
         # If the account has been deactivated, do not proceed with the login
         # flow.
@@ -1617,6 +1691,7 @@ class AuthHandler:
             extra_attributes,
             new_user=new_user,
             user_profile_data=profile,
+            auth_provider_session_id=auth_provider_session_id,
         )
 
     def _complete_sso_login(
@@ -1628,6 +1703,7 @@ class AuthHandler:
         extra_attributes: Optional[JsonDict] = None,
         new_user: bool = False,
         user_profile_data: Optional[ProfileInfo] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> None:
         """
         The synchronous portion of complete_sso_login.
@@ -1649,7 +1725,9 @@ class AuthHandler:
 
         # Create a login token
         login_token = self.macaroon_gen.generate_short_term_login_token(
-            registered_user_id, auth_provider_id=auth_provider_id
+            registered_user_id,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
         )
 
         # Append the login token to the original redirect URL (i.e. with its query
@@ -1754,6 +1832,7 @@ class MacaroonGenerator:
         self,
         user_id: str,
         auth_provider_id: str,
+        auth_provider_session_id: Optional[str] = None,
         duration_in_ms: int = (2 * 60 * 1000),
     ) -> str:
         macaroon = self._generate_base_macaroon(user_id)
@@ -1762,6 +1841,10 @@ class MacaroonGenerator:
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
         macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
+        if auth_provider_session_id is not None:
+            macaroon.add_first_party_caveat(
+                "auth_provider_session_id = %s" % (auth_provider_session_id,)
+            )
         return macaroon.serialize()
 
     def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
@@ -1783,15 +1866,28 @@ class MacaroonGenerator:
         user_id = get_value_from_macaroon(macaroon, "user_id")
         auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
 
+        auth_provider_session_id: Optional[str] = None
+        try:
+            auth_provider_session_id = get_value_from_macaroon(
+                macaroon, "auth_provider_session_id"
+            )
+        except MacaroonVerificationFailedException:
+            pass
+
         v = pymacaroons.Verifier()
         v.satisfy_exact("gen = 1")
         v.satisfy_exact("type = login")
         v.satisfy_general(lambda c: c.startswith("user_id = "))
         v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
+        v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
         satisfy_expiry(v, self.hs.get_clock().time_msec)
         v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
 
-        return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
+        return LoginTokenAttributes(
+            user_id=user_id,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
+        )
 
     def generate_delete_pusher_token(self, user_id: str) -> str:
         macaroon = self._generate_base_macaroon(user_id)
@@ -1828,13 +1924,6 @@ def load_single_legacy_password_auth_provider(
         logger.error("Error while initializing %r: %s", module, e)
         raise
 
-    # The known hooks. If a module implements a method who's name appears in this set
-    # we'll want to register it
-    password_auth_provider_methods = {
-        "check_3pid_auth",
-        "on_logged_out",
-    }
-
     # All methods that the module provides should be async, but this wasn't enforced
     # in the old module system, so we wrap them if needed
     def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
@@ -1919,11 +2008,14 @@ def load_single_legacy_password_auth_provider(
 
         return run
 
-    # populate hooks with the implemented methods, wrapped with async_wrapper
-    hooks = {
-        hook: async_wrapper(getattr(provider, hook, None))
-        for hook in password_auth_provider_methods
-    }
+    # If the module has these methods implemented, then we pull them out
+    # and register them as hooks.
+    check_3pid_auth_hook: Optional[CHECK_3PID_AUTH_CALLBACK] = async_wrapper(
+        getattr(provider, "check_3pid_auth", None)
+    )
+    on_logged_out_hook: Optional[ON_LOGGED_OUT_CALLBACK] = async_wrapper(
+        getattr(provider, "on_logged_out", None)
+    )
 
     supported_login_types = {}
     # call get_supported_login_types and add that to the dict
@@ -1950,7 +2042,11 @@ def load_single_legacy_password_auth_provider(
         # need to use a tuple here for ("password",) not a list since lists aren't hashable
         auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password
 
-    api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers)
+    api.register_password_auth_provider_callbacks(
+        check_3pid_auth=check_3pid_auth_hook,
+        on_logged_out=on_logged_out_hook,
+        auth_checkers=auth_checkers,
+    )
 
 
 CHECK_3PID_AUTH_CALLBACK = Callable[