1 files changed, 15 insertions, 1 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 0e12880ab5..1e62beaff8 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -45,8 +45,9 @@ class LoginRestServlet(ClientV1RestServlet):
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
-
self.cas_server_url = hs.config.cas_server_url
+ self.cas_required_attribute = hs.config.cas_required_attribute
+ self.cas_required_attribute_value = hs.config.cas_required_attribute_value
self.servername = hs.config.server_name
def on_GET(self, request):
@@ -126,6 +127,19 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def do_cas_login(self, cas_response_body):
(user, attributes) = self.parse_cas_response(cas_response_body)
+
+ if self.cas_required_attribute is not None:
+ # If required attribute was not in CAS Response - Forbidden
+ if self.cas_required_attribute not in attributes:
+ raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+
+ # Also need to check value
+ if self.cas_required_attribute_value is not None:
+ actualValue = attributes[self.cas_required_attribute]
+ # If required attribute value does not match expected - Forbidden
+ if self.cas_required_attribute_value != actualValue:
+ raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
|