Support multiple required attributes in CAS response, and in a nicer config format too
2 files changed, 10 insertions, 22 deletions
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 4d1dd8cc7b..e884d03fe6 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -27,28 +27,17 @@ class CasConfig(Config):
if cas_config:
self.cas_enabled = True
self.cas_server_url = cas_config["server_url"]
-
- if "required_attribute" in cas_config:
- self.cas_required_attribute = cas_config["required_attribute"]
- else:
- self.cas_required_attribute = None
-
- if "required_attribute_value" in cas_config:
- self.cas_required_attribute_value = cas_config["required_attribute_value"]
- else:
- self.cas_required_attribute_value = None
-
+ self.cas_required_attributes = cas_config.get("required_attributes", None)
else:
self.cas_enabled = False
self.cas_server_url = None
- self.cas_required_attribute = None
- self.cas_required_attribute_value = None
+ self.cas_required_attributes = {}
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable CAS for registration and login.
#cas_config:
# server_url: "https://cas-server.com"
- # #required_attribute: something
- # #required_attribute_value: true
+ # #required_attributes:
+ # # name: value
"""
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 1e62beaff8..84774e61aa 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -46,8 +46,7 @@ class LoginRestServlet(ClientV1RestServlet):
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.cas_server_url = hs.config.cas_server_url
- self.cas_required_attribute = hs.config.cas_required_attribute
- self.cas_required_attribute_value = hs.config.cas_required_attribute_value
+ self.cas_required_attributes = hs.config.cas_required_attributes
self.servername = hs.config.server_name
def on_GET(self, request):
@@ -128,16 +127,16 @@ class LoginRestServlet(ClientV1RestServlet):
def do_cas_login(self, cas_response_body):
(user, attributes) = self.parse_cas_response(cas_response_body)
- if self.cas_required_attribute is not None:
+ for required_attribute in self.cas_required_attributes:
# If required attribute was not in CAS Response - Forbidden
- if self.cas_required_attribute not in attributes:
+ if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
- if self.cas_required_attribute_value is not None:
- actualValue = attributes[self.cas_required_attribute]
+ if self.cas_required_attributes[required_attribute] is not None:
+ actualValue = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
- if self.cas_required_attribute_value != actualValue:
+ if self.cas_required_attributes[required_attribute] != actualValue:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
|