summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/config/cas.py15
-rw-r--r--synapse/rest/client/v1/login.py16
2 files changed, 30 insertions, 1 deletions
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 81d034e8f0..4d1dd8cc7b 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -27,13 +27,28 @@ class CasConfig(Config):
         if cas_config:
             self.cas_enabled = True
             self.cas_server_url = cas_config["server_url"]
+
+            if "required_attribute" in cas_config:
+                self.cas_required_attribute = cas_config["required_attribute"]
+            else:
+                self.cas_required_attribute = None
+
+            if "required_attribute_value" in cas_config:
+                self.cas_required_attribute_value = cas_config["required_attribute_value"]
+            else:
+                self.cas_required_attribute_value = None
+
         else:
             self.cas_enabled = False
             self.cas_server_url = None
+            self.cas_required_attribute = None
+            self.cas_required_attribute_value = None
 
     def default_config(self, config_dir_path, server_name, **kwargs):
         return """
         # Enable CAS for registration and login.
         #cas_config:
         #   server_url: "https://cas-server.com"
+        #   #required_attribute: something
+        #   #required_attribute_value: true
         """
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 0e12880ab5..1e62beaff8 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -45,8 +45,9 @@ class LoginRestServlet(ClientV1RestServlet):
         self.idp_redirect_url = hs.config.saml2_idp_redirect_url
         self.saml2_enabled = hs.config.saml2_enabled
         self.cas_enabled = hs.config.cas_enabled
-
         self.cas_server_url = hs.config.cas_server_url
+        self.cas_required_attribute = hs.config.cas_required_attribute
+        self.cas_required_attribute_value = hs.config.cas_required_attribute_value
         self.servername = hs.config.server_name
 
     def on_GET(self, request):
@@ -126,6 +127,19 @@ class LoginRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def do_cas_login(self, cas_response_body):
         (user, attributes) = self.parse_cas_response(cas_response_body)
+
+        if self.cas_required_attribute is not None:
+            # If required attribute was not in CAS Response - Forbidden
+            if self.cas_required_attribute not in attributes:
+                raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+
+            # Also need to check value
+            if self.cas_required_attribute_value is not None:
+                actualValue = attributes[self.cas_required_attribute]
+                # If required attribute value does not match expected - Forbidden
+                if self.cas_required_attribute_value != actualValue:
+                    raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+
         user_id = UserID.create(user, self.hs.hostname).to_string()
         auth_handler = self.handlers.auth_handler
         user_exists = yield auth_handler.does_user_exist(user_id)