diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index e9891e1316..fca210a5a6 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -22,6 +22,7 @@ import attr
from twisted.web.client import PartialDownloadError
from synapse.api.errors import HttpResponseException
+from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart
@@ -62,6 +63,7 @@ class CasHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self._hostname = hs.hostname
+ self._store = hs.get_datastore()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
@@ -72,6 +74,9 @@ class CasHandler:
self._http_client = hs.get_proxied_http_client()
+ # identifier for the external_ids table
+ self._auth_provider_id = "cas"
+
self._sso_handler = hs.get_sso_handler()
def _build_service_param(self, args: Dict[str, str]) -> str:
@@ -267,6 +272,14 @@ class CasHandler:
This should be the UI Auth session id.
"""
+ # first check if we're doing a UIA
+ if session:
+ return await self._sso_handler.complete_sso_ui_auth_request(
+ self._auth_provider_id, cas_response.username, session, request,
+ )
+
+ # otherwise, we're handling a login request.
+
# Ensure that the attributes of the logged in user meet the required
# attributes.
for required_attribute, required_value in self._cas_required_attributes.items():
@@ -293,54 +306,79 @@ class CasHandler:
)
return
- # Pull out the user-agent and IP from the request.
- user_agent = request.get_user_agent("")
- ip_address = self.hs.get_ip_from_request(request)
-
- # Get the matrix ID from the CAS username.
- user_id = await self._map_cas_user_to_matrix_user(
- cas_response, user_agent, ip_address
- )
+ # Call the mapper to register/login the user
- if session:
- await self._auth_handler.complete_sso_ui_auth(
- user_id, session, request,
- )
- else:
- # If this not a UI auth request than there must be a redirect URL.
- assert client_redirect_url
+ # If this not a UI auth request than there must be a redirect URL.
+ assert client_redirect_url is not None
- await self._auth_handler.complete_sso_login(
- user_id, request, client_redirect_url
- )
+ try:
+ await self._complete_cas_login(cas_response, request, client_redirect_url)
+ except MappingException as e:
+ logger.exception("Could not map user")
+ self._sso_handler.render_error(request, "mapping_error", str(e))
- async def _map_cas_user_to_matrix_user(
- self, cas_response: CasResponse, user_agent: str, ip_address: str,
- ) -> str:
+ async def _complete_cas_login(
+ self,
+ cas_response: CasResponse,
+ request: SynapseRequest,
+ client_redirect_url: str,
+ ) -> None:
"""
- Given a CAS username, retrieve the user ID for it and possibly register the user.
+ Given a CAS response, complete the login flow
+
+ Retrieves the remote user ID, registers the user if necessary, and serves
+ a redirect back to the client with a login-token.
Args:
cas_response: The parsed CAS response.
- user_agent: The user agent of the client making the request.
- ip_address: The IP address of the client making the request.
+ request: The request to respond to
+ client_redirect_url: The redirect URL passed in by the client.
- Returns:
- The user ID associated with this response.
+ Raises:
+ MappingException if there was a problem mapping the response to a user.
+ RedirectException: some mapping providers may raise this if they need
+ to redirect to an interstitial page.
"""
-
+ # Note that CAS does not support a mapping provider, so the logic is hard-coded.
localpart = map_username_to_mxid_localpart(cas_response.username)
- user_id = UserID(localpart, self._hostname).to_string()
- registered_user_id = await self._auth_handler.check_user_exists(user_id)
- displayname = cas_response.attributes.get(self._cas_displayname_attribute, None)
+ async def cas_response_to_user_attributes(failures: int) -> UserAttributes:
+ """
+ Map from CAS attributes to user attributes.
+ """
+ # Due to the grandfathering logic matching any previously registered
+ # mxids it isn't expected for there to be any failures.
+ if failures:
+ raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
+
+ display_name = cas_response.attributes.get(
+ self._cas_displayname_attribute, None
+ )
+
+ return UserAttributes(localpart=localpart, display_name=display_name)
- # If the user does not exist, register it.
- if not registered_user_id:
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=displayname,
- user_agent_ips=[(user_agent, ip_address)],
+ async def grandfather_existing_users() -> Optional[str]:
+ # Since CAS did not always use the user_external_ids table, always
+ # to attempt to map to existing users.
+ user_id = UserID(localpart, self._hostname).to_string()
+
+ logger.debug(
+ "Looking for existing account based on mapped %s", user_id,
)
- return registered_user_id
+ users = await self._store.get_users_by_id_case_insensitive(user_id)
+ if users:
+ registered_user_id = list(users.keys())[0]
+ logger.info("Grandfathering mapping to %s", registered_user_id)
+ return registered_user_id
+
+ return None
+
+ await self._sso_handler.complete_sso_login_request(
+ self._auth_provider_id,
+ cas_response.username,
+ request,
+ client_redirect_url,
+ cas_response_to_user_attributes,
+ grandfather_existing_users,
+ )
|