summary refs log tree commit diff
path: root/synapse/rest/client/v1/login.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/v1/login.py')
-rw-r--r--synapse/rest/client/v1/login.py13
1 files changed, 6 insertions, 7 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 1e62beaff8..84774e61aa 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -46,8 +46,7 @@ class LoginRestServlet(ClientV1RestServlet):
         self.saml2_enabled = hs.config.saml2_enabled
         self.cas_enabled = hs.config.cas_enabled
         self.cas_server_url = hs.config.cas_server_url
-        self.cas_required_attribute = hs.config.cas_required_attribute
-        self.cas_required_attribute_value = hs.config.cas_required_attribute_value
+        self.cas_required_attributes = hs.config.cas_required_attributes
         self.servername = hs.config.server_name
 
     def on_GET(self, request):
@@ -128,16 +127,16 @@ class LoginRestServlet(ClientV1RestServlet):
     def do_cas_login(self, cas_response_body):
         (user, attributes) = self.parse_cas_response(cas_response_body)
 
-        if self.cas_required_attribute is not None:
+        for required_attribute in self.cas_required_attributes:
             # If required attribute was not in CAS Response - Forbidden
-            if self.cas_required_attribute not in attributes:
+            if required_attribute not in attributes:
                 raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
 
             # Also need to check value
-            if self.cas_required_attribute_value is not None:
-                actualValue = attributes[self.cas_required_attribute]
+            if self.cas_required_attributes[required_attribute] is not None:
+                actualValue = attributes[required_attribute]
                 # If required attribute value does not match expected - Forbidden
-                if self.cas_required_attribute_value != actualValue:
+                if self.cas_required_attributes[required_attribute] != actualValue:
                     raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
 
         user_id = UserID.create(user, self.hs.hostname).to_string()