summary refs log tree commit diff
path: root/synapse/handlers/oidc_handler.py
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2020-12-08 14:03:38 +0000
committerGitHub <noreply@github.com>2020-12-08 14:03:38 +0000
commit36ba73f53d9919c7639d4c7269fabdb1857fb7a1 (patch)
tree445d0a6c7ea5c2be75dc0e43d1126b9c5168a314 /synapse/handlers/oidc_handler.py
parentClarify config template comments (#8891) (diff)
downloadsynapse-36ba73f53d9919c7639d4c7269fabdb1857fb7a1.tar.xz
Simplify the flow for SSO UIA (#8881)
* SsoHandler: remove inheritance from BaseHandler

* Simplify the flow for SSO UIA

We don't need to do all the magic for mapping users when we are doing UIA, so
let's factor that out.
Diffstat (limited to 'synapse/handlers/oidc_handler.py')
-rw-r--r--synapse/handlers/oidc_handler.py44
1 files changed, 32 insertions, 12 deletions
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index c605f7082a..f626117f76 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -674,6 +674,21 @@ class OidcHandler(BaseHandler):
                 self._sso_handler.render_error(request, "invalid_token", str(e))
                 return
 
+        # first check if we're doing a UIA
+        if ui_auth_session_id:
+            try:
+                remote_user_id = self._remote_id_from_userinfo(userinfo)
+            except Exception as e:
+                logger.exception("Could not extract remote user id")
+                self._sso_handler.render_error(request, "mapping_error", str(e))
+                return
+
+            return await self._sso_handler.complete_sso_ui_auth_request(
+                self._auth_provider_id, remote_user_id, ui_auth_session_id, request
+            )
+
+        # otherwise, it's a login
+
         # Pull out the user-agent and IP from the request.
         user_agent = request.get_user_agent("")
         ip_address = self.hs.get_ip_from_request(request)
@@ -698,14 +713,9 @@ class OidcHandler(BaseHandler):
             extra_attributes = await get_extra_attributes(userinfo, token)
 
         # and finally complete the login
-        if ui_auth_session_id:
-            await self._auth_handler.complete_sso_ui_auth(
-                user_id, ui_auth_session_id, request
-            )
-        else:
-            await self._auth_handler.complete_sso_login(
-                user_id, request, client_redirect_url, extra_attributes
-            )
+        await self._auth_handler.complete_sso_login(
+            user_id, request, client_redirect_url, extra_attributes
+        )
 
     def _generate_oidc_session_token(
         self,
@@ -856,14 +866,11 @@ class OidcHandler(BaseHandler):
             The mxid of the user
         """
         try:
-            remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+            remote_user_id = self._remote_id_from_userinfo(userinfo)
         except Exception as e:
             raise MappingException(
                 "Failed to extract subject from OIDC response: %s" % (e,)
             )
-        # Some OIDC providers use integer IDs, but Synapse expects external IDs
-        # to be strings.
-        remote_user_id = str(remote_user_id)
 
         # Older mapping providers don't accept the `failures` argument, so we
         # try and detect support.
@@ -933,6 +940,19 @@ class OidcHandler(BaseHandler):
             grandfather_existing_users,
         )
 
+    def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
+        """Extract the unique remote id from an OIDC UserInfo block
+
+        Args:
+            userinfo: An object representing the user given by the OIDC provider
+        Returns:
+            remote user id
+        """
+        remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+        # Some OIDC providers use integer IDs, but Synapse expects external IDs
+        # to be strings.
+        return str(remote_user_id)
+
 
 UserAttributeDict = TypedDict(
     "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}