diff options
author | Ben Banfield-Zanin <benbz@matrix.org> | 2020-12-16 14:49:53 +0000 |
---|---|---|
committer | Ben Banfield-Zanin <benbz@matrix.org> | 2020-12-16 14:49:53 +0000 |
commit | 0825299cfcf61079f78b7a6c5e31f5df078c291a (patch) | |
tree | 5f469584845d065c79f1f6ed4781d0624e87f4d3 /synapse/handlers/cas_handler.py | |
parent | Merge remote-tracking branch 'origin/release-v1.21.2' into bbz/info-mainline-... (diff) | |
parent | Add 'xmlsec1' to dependency list (diff) | |
download | synapse-0825299cfcf61079f78b7a6c5e31f5df078c291a.tar.xz |
Merge remote-tracking branch 'origin/release-v1.24.0' into bbz/info-mainline-1.24.0 github/bbz/info-mainline-1.24.0 bbz/info-mainline-1.24.0
Diffstat (limited to 'synapse/handlers/cas_handler.py')
-rw-r--r-- | synapse/handlers/cas_handler.py | 73 |
1 files changed, 52 insertions, 21 deletions
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index a4cc4b9a5a..f4ea0a9767 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import urllib -from typing import Dict, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Optional, Tuple from xml.etree import ElementTree as ET from twisted.web.client import PartialDownloadError @@ -23,6 +23,9 @@ from synapse.api.errors import Codes, LoginError from synapse.http.site import SynapseRequest from synapse.types import UserID, map_username_to_mxid_localpart +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -31,10 +34,10 @@ class CasHandler: Utility class for to handle the response from a CAS SSO service. Args: - hs (synapse.server.HomeServer) + hs """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self._hostname = hs.hostname self._auth_handler = hs.get_auth_handler() @@ -200,29 +203,57 @@ class CasHandler: args["session"] = session username, user_display_name = await self._validate_ticket(ticket, args) - localpart = map_username_to_mxid_localpart(username) - user_id = UserID(localpart, self._hostname).to_string() - registered_user_id = await self._auth_handler.check_user_exists(user_id) + # Pull out the user-agent and IP from the request. + user_agent = request.get_user_agent("") + ip_address = self.hs.get_ip_from_request(request) + + # Get the matrix ID from the CAS username. + user_id = await self._map_cas_user_to_matrix_user( + username, user_display_name, user_agent, ip_address + ) if session: await self._auth_handler.complete_sso_ui_auth( - registered_user_id, session, request, + user_id, session, request, ) - else: - if not registered_user_id: - # Pull out the user-agent and IP from the request. - user_agent = request.requestHeaders.getRawHeaders( - b"User-Agent", default=[b""] - )[0].decode("ascii", "surrogateescape") - ip_address = self.hs.get_ip_from_request(request) - - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, - default_display_name=user_display_name, - user_agent_ips=(user_agent, ip_address), - ) + # If this not a UI auth request than there must be a redirect URL. + assert client_redirect_url await self._auth_handler.complete_sso_login( - registered_user_id, request, client_redirect_url + user_id, request, client_redirect_url + ) + + async def _map_cas_user_to_matrix_user( + self, + remote_user_id: str, + display_name: Optional[str], + user_agent: str, + ip_address: str, + ) -> str: + """ + Given a CAS username, retrieve the user ID for it and possibly register the user. + + Args: + remote_user_id: The username from the CAS response. + display_name: The display name from the CAS response. + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. + + Returns: + The user ID associated with this response. + """ + + localpart = map_username_to_mxid_localpart(remote_user_id) + user_id = UserID(localpart, self._hostname).to_string() + registered_user_id = await self._auth_handler.check_user_exists(user_id) + + # If the user does not exist, register it. + if not registered_user_id: + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, + default_display_name=display_name, + user_agent_ips=[(user_agent, ip_address)], ) + + return registered_user_id |