summary refs log tree commit diff
path: root/synapse/handlers/cas_handler.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/cas_handler.py')
-rw-r--r--synapse/handlers/cas_handler.py71
1 files changed, 52 insertions, 19 deletions
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 048a3b3c0b..f4ea0a9767 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import urllib
-from typing import Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Optional, Tuple
 from xml.etree import ElementTree as ET
 
 from twisted.web.client import PartialDownloadError
@@ -23,6 +23,9 @@ from synapse.api.errors import Codes, LoginError
 from synapse.http.site import SynapseRequest
 from synapse.types import UserID, map_username_to_mxid_localpart
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -31,10 +34,10 @@ class CasHandler:
     Utility class for to handle the response from a CAS SSO service.
 
     Args:
-        hs (synapse.server.HomeServer)
+        hs
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self._hostname = hs.hostname
         self._auth_handler = hs.get_auth_handler()
@@ -200,27 +203,57 @@ class CasHandler:
             args["session"] = session
         username, user_display_name = await self._validate_ticket(ticket, args)
 
-        localpart = map_username_to_mxid_localpart(username)
-        user_id = UserID(localpart, self._hostname).to_string()
-        registered_user_id = await self._auth_handler.check_user_exists(user_id)
+        # Pull out the user-agent and IP from the request.
+        user_agent = request.get_user_agent("")
+        ip_address = self.hs.get_ip_from_request(request)
+
+        # Get the matrix ID from the CAS username.
+        user_id = await self._map_cas_user_to_matrix_user(
+            username, user_display_name, user_agent, ip_address
+        )
 
         if session:
             await self._auth_handler.complete_sso_ui_auth(
-                registered_user_id, session, request,
+                user_id, session, request,
             )
-
         else:
-            if not registered_user_id:
-                # Pull out the user-agent and IP from the request.
-                user_agent = request.get_user_agent("")
-                ip_address = self.hs.get_ip_from_request(request)
-
-                registered_user_id = await self._registration_handler.register_user(
-                    localpart=localpart,
-                    default_display_name=user_display_name,
-                    user_agent_ips=(user_agent, ip_address),
-                )
+            # If this not a UI auth request than there must be a redirect URL.
+            assert client_redirect_url
 
             await self._auth_handler.complete_sso_login(
-                registered_user_id, request, client_redirect_url
+                user_id, request, client_redirect_url
             )
+
+    async def _map_cas_user_to_matrix_user(
+        self,
+        remote_user_id: str,
+        display_name: Optional[str],
+        user_agent: str,
+        ip_address: str,
+    ) -> str:
+        """
+        Given a CAS username, retrieve the user ID for it and possibly register the user.
+
+        Args:
+            remote_user_id: The username from the CAS response.
+            display_name: The display name from the CAS response.
+            user_agent: The user agent of the client making the request.
+            ip_address: The IP address of the client making the request.
+
+        Returns:
+             The user ID associated with this response.
+        """
+
+        localpart = map_username_to_mxid_localpart(remote_user_id)
+        user_id = UserID(localpart, self._hostname).to_string()
+        registered_user_id = await self._auth_handler.check_user_exists(user_id)
+
+        # If the user does not exist, register it.
+        if not registered_user_id:
+            registered_user_id = await self._registration_handler.register_user(
+                localpart=localpart,
+                default_display_name=display_name,
+                user_agent_ips=[(user_agent, ip_address)],
+            )
+
+        return registered_user_id