diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index a99dcaab6f..0e12880ab5 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -125,6 +125,34 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def do_cas_login(self, cas_response_body):
+ (user, attributes) = self.parse_cas_response(cas_response_body)
+ user_id = UserID.create(user, self.hs.hostname).to_string()
+ auth_handler = self.handlers.auth_handler
+ user_exists = yield auth_handler.does_user_exist(user_id)
+ if user_exists:
+ user_id, access_token, refresh_token = (
+ yield auth_handler.login_with_cas_user_id(user_id)
+ )
+ result = {
+ "user_id": user_id, # may have changed
+ "access_token": access_token,
+ "refresh_token": refresh_token,
+ "home_server": self.hs.hostname,
+ }
+
+ else:
+ user_id, access_token = (
+ yield self.handlers.registration_handler.register(localpart=user)
+ )
+ result = {
+ "user_id": user_id, # may have changed
+ "access_token": access_token,
+ "home_server": self.hs.hostname,
+ }
+
+ defer.returnValue((200, result))
+
+ def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
@@ -133,33 +161,17 @@ class LoginRestServlet(ClientV1RestServlet):
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
- user_id = UserID.create(user, self.hs.hostname).to_string()
- auth_handler = self.handlers.auth_handler
- user_exists = yield auth_handler.does_user_exist(user_id)
- if user_exists:
- user_id, access_token, refresh_token = (
- yield auth_handler.login_with_cas_user_id(user_id)
- )
- result = {
- "user_id": user_id, # may have changed
- "access_token": access_token,
- "refresh_token": refresh_token,
- "home_server": self.hs.hostname,
- }
-
- else:
- user_id, access_token = (
- yield self.handlers.registration_handler.register(localpart=user)
- )
- result = {
- "user_id": user_id, # may have changed
- "access_token": access_token,
- "home_server": self.hs.hostname,
- }
-
- defer.returnValue((200, result))
+ if child.tag.endswith("attributes"):
+ attributes = {}
+ for attribute in child:
+ if "}" in attribute.tag:
+ attributes[attribute.tag.split("}")[1]] = attribute.text
+ else:
+ attributes[attribute.tag] = attribute.text
+ if user is None or attributes is None:
+ raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
- raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
+ return (user, attributes)
class LoginFallbackRestServlet(ClientV1RestServlet):
|