diff options
author | Patrick Cloke <clokep@users.noreply.github.com> | 2020-11-23 13:28:03 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-11-23 13:28:03 -0500 |
commit | 6fde6aa9c02d35e0a908437ea49b275df9b58427 (patch) | |
tree | 759d0114718fc2f8d58cd94a7776690ff6538ad6 /synapse/handlers/cas_handler.py | |
parent | Fix synctl and duplicate worker spawning (#8798) (diff) | |
download | synapse-6fde6aa9c02d35e0a908437ea49b275df9b58427.tar.xz |
Properly report user-agent/IP during registration of SSO users. (#8784)
This also expands type-hints to the SSO and registration code. Refactors the CAS code to more closely match OIDC/SAML.
Diffstat (limited to 'synapse/handlers/cas_handler.py')
-rw-r--r-- | synapse/handlers/cas_handler.py | 71 |
1 files changed, 52 insertions, 19 deletions
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 048a3b3c0b..f4ea0a9767 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import urllib -from typing import Dict, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Optional, Tuple from xml.etree import ElementTree as ET from twisted.web.client import PartialDownloadError @@ -23,6 +23,9 @@ from synapse.api.errors import Codes, LoginError from synapse.http.site import SynapseRequest from synapse.types import UserID, map_username_to_mxid_localpart +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -31,10 +34,10 @@ class CasHandler: Utility class for to handle the response from a CAS SSO service. Args: - hs (synapse.server.HomeServer) + hs """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self._hostname = hs.hostname self._auth_handler = hs.get_auth_handler() @@ -200,27 +203,57 @@ class CasHandler: args["session"] = session username, user_display_name = await self._validate_ticket(ticket, args) - localpart = map_username_to_mxid_localpart(username) - user_id = UserID(localpart, self._hostname).to_string() - registered_user_id = await self._auth_handler.check_user_exists(user_id) + # Pull out the user-agent and IP from the request. + user_agent = request.get_user_agent("") + ip_address = self.hs.get_ip_from_request(request) + + # Get the matrix ID from the CAS username. + user_id = await self._map_cas_user_to_matrix_user( + username, user_display_name, user_agent, ip_address + ) if session: await self._auth_handler.complete_sso_ui_auth( - registered_user_id, session, request, + user_id, session, request, ) - else: - if not registered_user_id: - # Pull out the user-agent and IP from the request. - user_agent = request.get_user_agent("") - ip_address = self.hs.get_ip_from_request(request) - - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, - default_display_name=user_display_name, - user_agent_ips=(user_agent, ip_address), - ) + # If this not a UI auth request than there must be a redirect URL. + assert client_redirect_url await self._auth_handler.complete_sso_login( - registered_user_id, request, client_redirect_url + user_id, request, client_redirect_url ) + + async def _map_cas_user_to_matrix_user( + self, + remote_user_id: str, + display_name: Optional[str], + user_agent: str, + ip_address: str, + ) -> str: + """ + Given a CAS username, retrieve the user ID for it and possibly register the user. + + Args: + remote_user_id: The username from the CAS response. + display_name: The display name from the CAS response. + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. + + Returns: + The user ID associated with this response. + """ + + localpart = map_username_to_mxid_localpart(remote_user_id) + user_id = UserID(localpart, self._hostname).to_string() + registered_user_id = await self._auth_handler.check_user_exists(user_id) + + # If the user does not exist, register it. + if not registered_user_id: + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, + default_display_name=display_name, + user_agent_ips=[(user_agent, ip_address)], + ) + + return registered_user_id |