diff --git a/changelog.d/8958.misc b/changelog.d/8958.misc
new file mode 100644
index 0000000000..1507073e4f
--- /dev/null
+++ b/changelog.d/8958.misc
@@ -0,0 +1 @@
+Properly store the mapping of external ID to Matrix ID for CAS users.
diff --git a/docs/dev/cas.md b/docs/dev/cas.md
index f8d02cc82c..592b2d8d4f 100644
--- a/docs/dev/cas.md
+++ b/docs/dev/cas.md
@@ -31,7 +31,7 @@ easy to run CAS implementation built on top of Django.
You should now have a Django project configured to serve CAS authentication with
a single user created.
-## Configure Synapse (and Riot) to use CAS
+## Configure Synapse (and Element) to use CAS
1. Modify your `homeserver.yaml` to enable CAS and point it to your locally
running Django test server:
@@ -51,9 +51,9 @@ and that the CAS server is on port 8000, both on localhost.
## Testing the configuration
-Then in Riot:
+Then in Element:
-1. Visit the login page with a Riot pointing at your homeserver.
+1. Visit the login page with a Element pointing at your homeserver.
2. Click the Single Sign-On button.
3. Login using the credentials created with `createsuperuser`.
4. You should be logged in.
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index f4ea0a9767..e9891e1316 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import urllib
-from typing import TYPE_CHECKING, Dict, Optional, Tuple
+import urllib.parse
+from typing import TYPE_CHECKING, Dict, Optional
from xml.etree import ElementTree as ET
+import attr
+
from twisted.web.client import PartialDownloadError
-from synapse.api.errors import Codes, LoginError
+from synapse.api.errors import HttpResponseException
from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart
@@ -29,6 +31,26 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class CasError(Exception):
+ """Used to catch errors when validating the CAS ticket.
+ """
+
+ def __init__(self, error, error_description=None):
+ self.error = error
+ self.error_description = error_description
+
+ def __str__(self):
+ if self.error_description:
+ return "{}: {}".format(self.error, self.error_description)
+ return self.error
+
+
+@attr.s(slots=True, frozen=True)
+class CasResponse:
+ username = attr.ib(type=str)
+ attributes = attr.ib(type=Dict[str, Optional[str]])
+
+
class CasHandler:
"""
Utility class for to handle the response from a CAS SSO service.
@@ -50,6 +72,8 @@ class CasHandler:
self._http_client = hs.get_proxied_http_client()
+ self._sso_handler = hs.get_sso_handler()
+
def _build_service_param(self, args: Dict[str, str]) -> str:
"""
Generates a value to use as the "service" parameter when redirecting or
@@ -69,14 +93,20 @@ class CasHandler:
async def _validate_ticket(
self, ticket: str, service_args: Dict[str, str]
- ) -> Tuple[str, Optional[str]]:
+ ) -> CasResponse:
"""
- Validate a CAS ticket with the server, parse the response, and return the user and display name.
+ Validate a CAS ticket with the server, and return the parsed the response.
Args:
ticket: The CAS ticket from the client.
service_args: Additional arguments to include in the service URL.
Should be the same as those passed to `get_redirect_url`.
+
+ Raises:
+ CasError: If there's an error parsing the CAS response.
+
+ Returns:
+ The parsed CAS response.
"""
uri = self._cas_server_url + "/proxyValidate"
args = {
@@ -89,66 +119,65 @@ class CasHandler:
# Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data
body = pde.response
+ except HttpResponseException as e:
+ description = (
+ (
+ 'Authorization server responded with a "{status}" error '
+ "while exchanging the authorization code."
+ ).format(status=e.code),
+ )
+ raise CasError("server_error", description) from e
- user, attributes = self._parse_cas_response(body)
- displayname = attributes.pop(self._cas_displayname_attribute, None)
-
- for required_attribute, required_value in self._cas_required_attributes.items():
- # If required attribute was not in CAS Response - Forbidden
- if required_attribute not in attributes:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
- # Also need to check value
- if required_value is not None:
- actual_value = attributes[required_attribute]
- # If required attribute value does not match expected - Forbidden
- if required_value != actual_value:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
- return user, displayname
+ return self._parse_cas_response(body)
- def _parse_cas_response(
- self, cas_response_body: bytes
- ) -> Tuple[str, Dict[str, Optional[str]]]:
+ def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
"""
Retrieve the user and other parameters from the CAS response.
Args:
cas_response_body: The response from the CAS query.
+ Raises:
+ CasError: If there's an error parsing the CAS response.
+
Returns:
- A tuple of the user and a mapping of other attributes.
+ The parsed CAS response.
"""
+
+ # Ensure the response is valid.
+ root = ET.fromstring(cas_response_body)
+ if not root.tag.endswith("serviceResponse"):
+ raise CasError(
+ "missing_service_response",
+ "root of CAS response is not serviceResponse",
+ )
+
+ success = root[0].tag.endswith("authenticationSuccess")
+ if not success:
+ raise CasError("unsucessful_response", "Unsuccessful CAS response")
+
+ # Iterate through the nodes and pull out the user and any extra attributes.
user = None
attributes = {}
- try:
- root = ET.fromstring(cas_response_body)
- if not root.tag.endswith("serviceResponse"):
- raise Exception("root of CAS response is not serviceResponse")
- success = root[0].tag.endswith("authenticationSuccess")
- for child in root[0]:
- if child.tag.endswith("user"):
- user = child.text
- if child.tag.endswith("attributes"):
- for attribute in child:
- # ElementTree library expands the namespace in
- # attribute tags to the full URL of the namespace.
- # We don't care about namespace here and it will always
- # be encased in curly braces, so we remove them.
- tag = attribute.tag
- if "}" in tag:
- tag = tag.split("}")[1]
- attributes[tag] = attribute.text
- if user is None:
- raise Exception("CAS response does not contain user")
- except Exception:
- logger.exception("Error parsing CAS response")
- raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
- if not success:
- raise LoginError(
- 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
- )
- return user, attributes
+ for child in root[0]:
+ if child.tag.endswith("user"):
+ user = child.text
+ if child.tag.endswith("attributes"):
+ for attribute in child:
+ # ElementTree library expands the namespace in
+ # attribute tags to the full URL of the namespace.
+ # We don't care about namespace here and it will always
+ # be encased in curly braces, so we remove them.
+ tag = attribute.tag
+ if "}" in tag:
+ tag = tag.split("}")[1]
+ attributes[tag] = attribute.text
+
+ # Ensure a user was found.
+ if user is None:
+ raise CasError("no_user", "CAS response does not contain user")
+
+ return CasResponse(user, attributes)
def get_redirect_url(self, service_args: Dict[str, str]) -> str:
"""
@@ -201,7 +230,68 @@ class CasHandler:
args["redirectUrl"] = client_redirect_url
if session:
args["session"] = session
- username, user_display_name = await self._validate_ticket(ticket, args)
+
+ try:
+ cas_response = await self._validate_ticket(ticket, args)
+ except CasError as e:
+ logger.exception("Could not validate ticket")
+ self._sso_handler.render_error(request, e.error, e.error_description, 401)
+ return
+
+ await self._handle_cas_response(
+ request, cas_response, client_redirect_url, session
+ )
+
+ async def _handle_cas_response(
+ self,
+ request: SynapseRequest,
+ cas_response: CasResponse,
+ client_redirect_url: Optional[str],
+ session: Optional[str],
+ ) -> None:
+ """Handle a CAS response to a ticket request.
+
+ Assumes that the response has been validated. Maps the user onto an MXID,
+ registering them if necessary, and returns a response to the browser.
+
+ Args:
+ request: the incoming request from the browser. We'll respond to it with an
+ HTML page or a redirect
+
+ cas_response: The parsed CAS response.
+
+ client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
+ This should be the same as the redirectUrl from the original `/login/sso/redirect` request.
+
+ session: The session parameter from the `/cas/ticket` HTTP request, if given.
+ This should be the UI Auth session id.
+ """
+
+ # Ensure that the attributes of the logged in user meet the required
+ # attributes.
+ for required_attribute, required_value in self._cas_required_attributes.items():
+ # If required attribute was not in CAS Response - Forbidden
+ if required_attribute not in cas_response.attributes:
+ self._sso_handler.render_error(
+ request,
+ "unauthorised",
+ "You are not authorised to log in here.",
+ 401,
+ )
+ return
+
+ # Also need to check value
+ if required_value is not None:
+ actual_value = cas_response.attributes[required_attribute]
+ # If required attribute value does not match expected - Forbidden
+ if required_value != actual_value:
+ self._sso_handler.render_error(
+ request,
+ "unauthorised",
+ "You are not authorised to log in here.",
+ 401,
+ )
+ return
# Pull out the user-agent and IP from the request.
user_agent = request.get_user_agent("")
@@ -209,7 +299,7 @@ class CasHandler:
# Get the matrix ID from the CAS username.
user_id = await self._map_cas_user_to_matrix_user(
- username, user_display_name, user_agent, ip_address
+ cas_response, user_agent, ip_address
)
if session:
@@ -225,18 +315,13 @@ class CasHandler:
)
async def _map_cas_user_to_matrix_user(
- self,
- remote_user_id: str,
- display_name: Optional[str],
- user_agent: str,
- ip_address: str,
+ self, cas_response: CasResponse, user_agent: str, ip_address: str,
) -> str:
"""
Given a CAS username, retrieve the user ID for it and possibly register the user.
Args:
- remote_user_id: The username from the CAS response.
- display_name: The display name from the CAS response.
+ cas_response: The parsed CAS response.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
@@ -244,15 +329,17 @@ class CasHandler:
The user ID associated with this response.
"""
- localpart = map_username_to_mxid_localpart(remote_user_id)
+ localpart = map_username_to_mxid_localpart(cas_response.username)
user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = await self._auth_handler.check_user_exists(user_id)
+ displayname = cas_response.attributes.get(self._cas_displayname_attribute, None)
+
# If the user does not exist, register it.
if not registered_user_id:
registered_user_id = await self._registration_handler.register_user(
localpart=localpart,
- default_display_name=display_name,
+ default_display_name=displayname,
user_agent_ips=[(user_agent, ip_address)],
)
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 548b02211b..b0a8c8c7d2 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -101,7 +101,11 @@ class SsoHandler:
self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession]
def render_error(
- self, request, error: str, error_description: Optional[str] = None
+ self,
+ request: Request,
+ error: str,
+ error_description: Optional[str] = None,
+ code: int = 400,
) -> None:
"""Renders the error template and responds with it.
@@ -113,11 +117,12 @@ class SsoHandler:
We'll respond with an HTML page describing the error.
error: A technical identifier for this error.
error_description: A human-readable description of the error.
+ code: The integer error code (an HTTP response code)
"""
html = self._error_template.render(
error=error, error_description=error_description
)
- respond_with_html(request, 400, html)
+ respond_with_html(request, code, html)
async def get_sso_user_by_remote_user_id(
self, auth_provider_id: str, remote_user_id: str
|