summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/cas_handler.py112
-rw-r--r--synapse/handlers/sso.py4
2 files changed, 77 insertions, 39 deletions
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index e9891e1316..fca210a5a6 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -22,6 +22,7 @@ import attr
 from twisted.web.client import PartialDownloadError
 
 from synapse.api.errors import HttpResponseException
+from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.site import SynapseRequest
 from synapse.types import UserID, map_username_to_mxid_localpart
 
@@ -62,6 +63,7 @@ class CasHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self._hostname = hs.hostname
+        self._store = hs.get_datastore()
         self._auth_handler = hs.get_auth_handler()
         self._registration_handler = hs.get_registration_handler()
 
@@ -72,6 +74,9 @@ class CasHandler:
 
         self._http_client = hs.get_proxied_http_client()
 
+        # identifier for the external_ids table
+        self._auth_provider_id = "cas"
+
         self._sso_handler = hs.get_sso_handler()
 
     def _build_service_param(self, args: Dict[str, str]) -> str:
@@ -267,6 +272,14 @@ class CasHandler:
                 This should be the UI Auth session id.
         """
 
+        # first check if we're doing a UIA
+        if session:
+            return await self._sso_handler.complete_sso_ui_auth_request(
+                self._auth_provider_id, cas_response.username, session, request,
+            )
+
+        # otherwise, we're handling a login request.
+
         # Ensure that the attributes of the logged in user meet the required
         # attributes.
         for required_attribute, required_value in self._cas_required_attributes.items():
@@ -293,54 +306,79 @@ class CasHandler:
                     )
                     return
 
-        # Pull out the user-agent and IP from the request.
-        user_agent = request.get_user_agent("")
-        ip_address = self.hs.get_ip_from_request(request)
-
-        # Get the matrix ID from the CAS username.
-        user_id = await self._map_cas_user_to_matrix_user(
-            cas_response, user_agent, ip_address
-        )
+        # Call the mapper to register/login the user
 
-        if session:
-            await self._auth_handler.complete_sso_ui_auth(
-                user_id, session, request,
-            )
-        else:
-            # If this not a UI auth request than there must be a redirect URL.
-            assert client_redirect_url
+        # If this not a UI auth request than there must be a redirect URL.
+        assert client_redirect_url is not None
 
-            await self._auth_handler.complete_sso_login(
-                user_id, request, client_redirect_url
-            )
+        try:
+            await self._complete_cas_login(cas_response, request, client_redirect_url)
+        except MappingException as e:
+            logger.exception("Could not map user")
+            self._sso_handler.render_error(request, "mapping_error", str(e))
 
-    async def _map_cas_user_to_matrix_user(
-        self, cas_response: CasResponse, user_agent: str, ip_address: str,
-    ) -> str:
+    async def _complete_cas_login(
+        self,
+        cas_response: CasResponse,
+        request: SynapseRequest,
+        client_redirect_url: str,
+    ) -> None:
         """
-        Given a CAS username, retrieve the user ID for it and possibly register the user.
+        Given a CAS response, complete the login flow
+
+        Retrieves the remote user ID, registers the user if necessary, and serves
+        a redirect back to the client with a login-token.
 
         Args:
             cas_response: The parsed CAS response.
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
+            request: The request to respond to
+            client_redirect_url: The redirect URL passed in by the client.
 
-        Returns:
-             The user ID associated with this response.
+        Raises:
+            MappingException if there was a problem mapping the response to a user.
+            RedirectException: some mapping providers may raise this if they need
+                to redirect to an interstitial page.
         """
-
+        # Note that CAS does not support a mapping provider, so the logic is hard-coded.
         localpart = map_username_to_mxid_localpart(cas_response.username)
-        user_id = UserID(localpart, self._hostname).to_string()
-        registered_user_id = await self._auth_handler.check_user_exists(user_id)
 
-        displayname = cas_response.attributes.get(self._cas_displayname_attribute, None)
+        async def cas_response_to_user_attributes(failures: int) -> UserAttributes:
+            """
+            Map from CAS attributes to user attributes.
+            """
+            # Due to the grandfathering logic matching any previously registered
+            # mxids it isn't expected for there to be any failures.
+            if failures:
+                raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
+
+            display_name = cas_response.attributes.get(
+                self._cas_displayname_attribute, None
+            )
+
+            return UserAttributes(localpart=localpart, display_name=display_name)
 
-        # If the user does not exist, register it.
-        if not registered_user_id:
-            registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart,
-                default_display_name=displayname,
-                user_agent_ips=[(user_agent, ip_address)],
+        async def grandfather_existing_users() -> Optional[str]:
+            # Since CAS did not always use the user_external_ids table, always
+            # to attempt to map to existing users.
+            user_id = UserID(localpart, self._hostname).to_string()
+
+            logger.debug(
+                "Looking for existing account based on mapped %s", user_id,
             )
 
-        return registered_user_id
+            users = await self._store.get_users_by_id_case_insensitive(user_id)
+            if users:
+                registered_user_id = list(users.keys())[0]
+                logger.info("Grandfathering mapping to %s", registered_user_id)
+                return registered_user_id
+
+            return None
+
+        await self._sso_handler.complete_sso_login_request(
+            self._auth_provider_id,
+            cas_response.username,
+            request,
+            client_redirect_url,
+            cas_response_to_user_attributes,
+            grandfather_existing_users,
+        )
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index b0a8c8c7d2..33cd6bc178 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -173,7 +173,7 @@ class SsoHandler:
         request: SynapseRequest,
         client_redirect_url: str,
         sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
-        grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
+        grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
         extra_login_attributes: Optional[JsonDict] = None,
     ) -> None:
         """
@@ -241,7 +241,7 @@ class SsoHandler:
             )
 
             # Check for grandfathering of users.
-            if not user_id and grandfather_existing_users:
+            if not user_id:
                 user_id = await grandfather_existing_users()
                 if user_id:
                     # Future logins should also match this user ID.