diff options
Diffstat (limited to 'synapse/rest')
-rw-r--r-- | synapse/rest/client/v1/login.py | 13 |
1 files changed, 6 insertions, 7 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 1e62beaff8..84774e61aa 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -46,8 +46,7 @@ class LoginRestServlet(ClientV1RestServlet): self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled self.cas_server_url = hs.config.cas_server_url - self.cas_required_attribute = hs.config.cas_required_attribute - self.cas_required_attribute_value = hs.config.cas_required_attribute_value + self.cas_required_attributes = hs.config.cas_required_attributes self.servername = hs.config.server_name def on_GET(self, request): @@ -128,16 +127,16 @@ class LoginRestServlet(ClientV1RestServlet): def do_cas_login(self, cas_response_body): (user, attributes) = self.parse_cas_response(cas_response_body) - if self.cas_required_attribute is not None: + for required_attribute in self.cas_required_attributes: # If required attribute was not in CAS Response - Forbidden - if self.cas_required_attribute not in attributes: + if required_attribute not in attributes: raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) # Also need to check value - if self.cas_required_attribute_value is not None: - actualValue = attributes[self.cas_required_attribute] + if self.cas_required_attributes[required_attribute] is not None: + actualValue = attributes[required_attribute] # If required attribute value does not match expected - Forbidden - if self.cas_required_attribute_value != actualValue: + if self.cas_required_attributes[required_attribute] != actualValue: raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) user_id = UserID.create(user, self.hs.hostname).to_string() |