summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py60
1 files changed, 59 insertions, 1 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 0322b60cfc..00eae92052 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -137,6 +137,15 @@ def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]:
     }
 
 
+@attr.s(slots=True)
+class SsoLoginExtraAttributes:
+    """Data we track about SAML2 sessions"""
+
+    # time the session was created, in milliseconds
+    creation_time = attr.ib(type=int)
+    extra_attributes = attr.ib(type=JsonDict)
+
+
 class AuthHandler(BaseHandler):
     SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
 
@@ -239,6 +248,10 @@ class AuthHandler(BaseHandler):
         # cast to tuple for use with str.startswith
         self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
 
+        # A mapping of user ID to extra attributes to include in the login
+        # response.
+        self._extra_attributes = {}  # type: Dict[str, SsoLoginExtraAttributes]
+
     async def validate_user_via_ui_auth(
         self,
         requester: Requester,
@@ -1165,6 +1178,7 @@ class AuthHandler(BaseHandler):
         registered_user_id: str,
         request: SynapseRequest,
         client_redirect_url: str,
+        extra_attributes: Optional[JsonDict] = None,
     ):
         """Having figured out a mxid for this user, complete the HTTP request
 
@@ -1173,6 +1187,8 @@ class AuthHandler(BaseHandler):
             request: The request to complete.
             client_redirect_url: The URL to which to redirect the user at the end of the
                 process.
+            extra_attributes: Extra attributes which will be passed to the client
+                during successful login. Must be JSON serializable.
         """
         # If the account has been deactivated, do not proceed with the login
         # flow.
@@ -1181,19 +1197,30 @@ class AuthHandler(BaseHandler):
             respond_with_html(request, 403, self._sso_account_deactivated_template)
             return
 
-        self._complete_sso_login(registered_user_id, request, client_redirect_url)
+        self._complete_sso_login(
+            registered_user_id, request, client_redirect_url, extra_attributes
+        )
 
     def _complete_sso_login(
         self,
         registered_user_id: str,
         request: SynapseRequest,
         client_redirect_url: str,
+        extra_attributes: Optional[JsonDict] = None,
     ):
         """
         The synchronous portion of complete_sso_login.
 
         This exists purely for backwards compatibility of synapse.module_api.ModuleApi.
         """
+        # Store any extra attributes which will be passed in the login response.
+        # Note that this is per-user so it may overwrite a previous value, this
+        # is considered OK since the newest SSO attributes should be most valid.
+        if extra_attributes:
+            self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes(
+                self._clock.time_msec(), extra_attributes,
+            )
+
         # Create a login token
         login_token = self.macaroon_gen.generate_short_term_login_token(
             registered_user_id
@@ -1226,6 +1253,37 @@ class AuthHandler(BaseHandler):
         )
         respond_with_html(request, 200, html)
 
+    async def _sso_login_callback(self, login_result: JsonDict) -> None:
+        """
+        A login callback which might add additional attributes to the login response.
+
+        Args:
+            login_result: The data to be sent to the client. Includes the user
+                ID and access token.
+        """
+        # Expire attributes before processing. Note that there shouldn't be any
+        # valid logins that still have extra attributes.
+        self._expire_sso_extra_attributes()
+
+        extra_attributes = self._extra_attributes.get(login_result["user_id"])
+        if extra_attributes:
+            login_result.update(extra_attributes.extra_attributes)
+
+    def _expire_sso_extra_attributes(self) -> None:
+        """
+        Iterate through the mapping of user IDs to extra attributes and remove any that are no longer valid.
+        """
+        # TODO This should match the amount of time the macaroon is valid for.
+        LOGIN_TOKEN_EXPIRATION_TIME = 2 * 60 * 1000
+        expire_before = self._clock.time_msec() - LOGIN_TOKEN_EXPIRATION_TIME
+        to_expire = set()
+        for user_id, data in self._extra_attributes.items():
+            if data.creation_time < expire_before:
+                to_expire.add(user_id)
+        for user_id in to_expire:
+            logger.debug("Expiring extra attributes for user %s", user_id)
+            del self._extra_attributes[user_id]
+
     @staticmethod
     def add_query_param_to_url(url: str, param_name: str, param: Any):
         url_parts = list(urllib.parse.urlparse(url))