summary refs log tree commit diff
path: root/tests/handlers/test_cas.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-02-11 10:05:15 -0500
committerGitHub <noreply@github.com>2021-02-11 10:05:15 -0500
commit6dade80048380166ac7543d96c4d4686401b1e37 (patch)
tree31e9f226a6f77a701a5849878c2b0cffd71b89c6 /tests/handlers/test_cas.py
parentRemove conflicting sqlite tables that are "reserved" (shadow fts4 tables) (#9... (diff)
downloadsynapse-6dade80048380166ac7543d96c4d4686401b1e37.tar.xz
Combine the CAS & SAML implementations for required attributes. (#9326)
Diffstat (limited to 'tests/handlers/test_cas.py')
-rw-r--r--tests/handlers/test_cas.py52
1 files changed, 50 insertions, 2 deletions
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 7baf224f7e..6f992291b8 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -16,7 +16,7 @@ from mock import Mock
 from synapse.handlers.cas_handler import CasResponse
 
 from tests.test_utils import simple_async_mock
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
 
 # These are a few constants that are used as config parameters in the tests.
 BASE_URL = "https://synapse/"
@@ -32,6 +32,10 @@ class CasHandlerTestCase(HomeserverTestCase):
             "server_url": SERVER_URL,
             "service_url": BASE_URL,
         }
+
+        # Update this config with what's in the default config so that
+        # override_config works as expected.
+        cas_config.update(config.get("cas_config", {}))
         config["cas_config"] = cas_config
 
         return config
@@ -115,7 +119,51 @@ class CasHandlerTestCase(HomeserverTestCase):
             "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
         )
 
+    @override_config(
+        {
+            "cas_config": {
+                "required_attributes": {"userGroup": "staff", "department": None}
+            }
+        }
+    )
+    def test_required_attributes(self):
+        """The required attributes must be met from the CAS response."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # The response doesn't have the proper userGroup or department.
+        cas_response = CasResponse("test_user", {})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # The response doesn't have any department.
+        cas_response = CasResponse("test_user", {"userGroup": "staff"})
+        request.reset_mock()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # Add the proper attributes and it should succeed.
+        cas_response = CasResponse(
+            "test_user", {"userGroup": ["staff", "admin"], "department": ["sales"]}
+        )
+        request.reset_mock()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None, new_user=True
+        )
+
 
 def _mock_request():
     """Returns a mock which will stand in as a SynapseRequest"""
-    return Mock(spec=["getClientIP", "getHeader"])
+    return Mock(spec=["getClientIP", "getHeader", "_disconnected"])