summary refs log tree commit diff
diff options
context:
space:
mode:
authorAzrenbeth <7782548+Azrenbeth@users.noreply.github.com>2021-08-24 11:34:26 +0100
committerAzrenbeth <7782548+Azrenbeth@users.noreply.github.com>2021-08-24 14:38:22 +0100
commit162738feb6a5ba049a56926e560880158624a9a9 (patch)
treece930a58b68dec81a90f5dac34ae46059e570891
parentPort the saml mapping providers to new module interface (diff)
downloadsynapse-162738feb6a5ba049a56926e560880158624a9a9.tar.xz
Updated tests to use new module system
-rw-r--r--tests/handlers/test_saml.py178
1 files changed, 174 insertions, 4 deletions
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 8cfc184fef..4df6a4d029 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -18,6 +18,7 @@ from unittest.mock import Mock
 import attr
 
 from synapse.api.errors import RedirectException
+from synapse.handlers.saml import load_default_or_legacy_saml2_mapping_provider
 
 from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
@@ -51,7 +52,7 @@ class FakeAuthnResponse:
     in_response_to = attr.ib(type=Optional[str], default=None)
 
 
-class TestMappingProvider:
+class LegacyTestMappingProvider:
     def __init__(self, config, module):
         pass
 
@@ -73,6 +74,31 @@ class TestMappingProvider:
         return {"mxid_localpart": localpart, "displayname": None}
 
 
+class LegacyTestRedirectMappingProvider(LegacyTestMappingProvider):
+    def saml_response_to_user_attributes(
+        self, saml_response, failures, client_redirect_url
+    ):
+        raise RedirectException(b"https://custom-saml-redirect/")
+
+
+class TestMappingProvider:
+    def __init__(self, config, api):
+        api.register_saml2_user_mapping_provider_callbacks(
+            get_remote_user_id=self.get_remote_user_id,
+            saml_response_to_user_attributes=self.saml_response_to_user_attributes,
+            saml_attributes=({"uid"}, {"displayName"}),
+        )
+
+    async def get_remote_user_id(self, saml_response, client_redirect_url):
+        return saml_response.ava["uid"]
+
+    async def saml_response_to_user_attributes(
+        self, saml_response, failures, client_redirect_url
+    ):
+        localpart = saml_response.ava["username"] + (str(failures) if failures else "")
+        return {"mxid_localpart": localpart, "displayname": None}
+
+
 class TestRedirectMappingProvider(TestMappingProvider):
     def saml_response_to_user_attributes(
         self, saml_response, failures, client_redirect_url
@@ -88,7 +114,6 @@ class SamlHandlerTestCase(HomeserverTestCase):
             "sp_config": {"metadata": {}},
             # Disable grandfathering.
             "grandfathered_mxid_source_attribute": None,
-            "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
         }
 
         # Update this config with what's in the default config so that
@@ -101,6 +126,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver()
 
+        module_api = hs.get_module_api()
+        for module, config in hs.config.modules.loaded_modules:
+            module(config=config, api=module_api)
+
+        if not hs.get_saml2_user_mapping_provider().module_has_registered:
+            load_default_or_legacy_saml2_mapping_provider(hs)
+
         self.handler = hs.get_saml_handler()
 
         # Reduce the number of attempts when generating MXIDs.
@@ -114,7 +146,31 @@ class SamlHandlerTestCase(HomeserverTestCase):
     elif not has_xmlsec1:
         skip = "Requires xmlsec1"
 
+    @override_config(
+        {
+            "saml2_config": {
+                "user_mapping_provider": {
+                    "module": __name__ + ".LegacyTestMappingProvider"
+                },
+            }
+        }
+    )
+    def test_map_saml_response_to_user_legacy(self):
+        self.map_saml_response_to_user_body()
+
+    @override_config(
+        {
+            "modules": [
+                {
+                    "module": __name__ + ".TestMappingProvider",
+                }
+            ]
+        }
+    )
     def test_map_saml_response_to_user(self):
+        self.map_saml_response_to_user_body()
+
+    def map_saml_response_to_user_body(self):
         """Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
 
         # stub out the auth handler
@@ -133,8 +189,35 @@ class SamlHandlerTestCase(HomeserverTestCase):
             "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
         )
 
-    @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
+    @override_config(
+        {
+            "saml2_config": {
+                "user_mapping_provider": {
+                    "module": __name__ + ".LegacyTestMappingProvider"
+                },
+                "grandfathered_mxid_source_attribute": "mxid",
+            }
+        }
+    )
+    def test_map_saml_response_to_existing_user_legacy(self):
+        self.map_saml_response_to_existing_user_body()
+
+    @override_config(
+        {
+            "modules": [
+                {
+                    "module": __name__ + ".TestMappingProvider",
+                }
+            ],
+            "saml2_config": {
+                "grandfathered_mxid_source_attribute": "mxid",
+            },
+        }
+    )
     def test_map_saml_response_to_existing_user(self):
+        self.map_saml_response_to_existing_user_body()
+
+    def map_saml_response_to_existing_user_body(self):
         """Existing users can log in with SAML account."""
         store = self.hs.get_datastore()
         self.get_success(
@@ -168,7 +251,31 @@ class SamlHandlerTestCase(HomeserverTestCase):
             "@test_user:test", "saml", request, "", None, new_user=False
         )
 
+    @override_config(
+        {
+            "saml2_config": {
+                "user_mapping_provider": {
+                    "module": __name__ + ".LegacyTestMappingProvider"
+                },
+            }
+        }
+    )
+    def test_map_saml_response_to_invalid_localpart_legacy(self):
+        self.map_saml_response_to_invalid_localpart_body()
+
+    @override_config(
+        {
+            "modules": [
+                {
+                    "module": __name__ + ".TestMappingProvider",
+                }
+            ]
+        }
+    )
     def test_map_saml_response_to_invalid_localpart(self):
+        self.map_saml_response_to_invalid_localpart_body()
+
+    def map_saml_response_to_invalid_localpart_body(self):
         """If the mapping provider generates an invalid localpart it should be rejected."""
 
         # stub out the auth handler
@@ -189,7 +296,31 @@ class SamlHandlerTestCase(HomeserverTestCase):
         )
         auth_handler.complete_sso_login.assert_not_called()
 
+    @override_config(
+        {
+            "saml2_config": {
+                "user_mapping_provider": {
+                    "module": __name__ + ".LegacyTestMappingProvider"
+                },
+            }
+        }
+    )
+    def test_map_saml_response_to_user_retries_legacy(self):
+        self.map_saml_response_to_user_retries_body()
+
+    @override_config(
+        {
+            "modules": [
+                {
+                    "module": __name__ + ".TestMappingProvider",
+                }
+            ]
+        }
+    )
     def test_map_saml_response_to_user_retries(self):
+        self.map_saml_response_to_user_retries_body()
+
+    def map_saml_response_to_user_retries_body(self):
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
 
         # stub out the auth handler and error renderer
@@ -242,12 +373,27 @@ class SamlHandlerTestCase(HomeserverTestCase):
         {
             "saml2_config": {
                 "user_mapping_provider": {
-                    "module": __name__ + ".TestRedirectMappingProvider"
+                    "module": __name__ + ".LegacyTestRedirectMappingProvider"
                 },
             }
         }
     )
+    def test_map_saml_response_redirect_legacy(self):
+        self.map_saml_response_redirect_body()
+
+    @override_config(
+        {
+            "modules": [
+                {
+                    "module": __name__ + ".TestRedirectMappingProvider",
+                }
+            ]
+        }
+    )
     def test_map_saml_response_redirect(self):
+        self.map_saml_response_redirect_body()
+
+    def map_saml_response_redirect_body(self):
         """Test a mapping provider that raises a RedirectException"""
 
         saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
@@ -261,6 +407,27 @@ class SamlHandlerTestCase(HomeserverTestCase):
     @override_config(
         {
             "saml2_config": {
+                "user_mapping_provider": {
+                    "module": __name__ + ".LegacyTestMappingProvider"
+                },
+                "attribute_requirements": [
+                    {"attribute": "userGroup", "value": "staff"},
+                    {"attribute": "department", "value": "sales"},
+                ],
+            },
+        }
+    )
+    def test_attribute_requirements_legacy(self):
+        self.attribute_requirements_body()
+
+    @override_config(
+        {
+            "modules": [
+                {
+                    "module": __name__ + ".TestMappingProvider",
+                }
+            ],
+            "saml2_config": {
                 "attribute_requirements": [
                     {"attribute": "userGroup", "value": "staff"},
                     {"attribute": "department", "value": "sales"},
@@ -269,6 +436,9 @@ class SamlHandlerTestCase(HomeserverTestCase):
         }
     )
     def test_attribute_requirements(self):
+        self.attribute_requirements_body()
+
+    def attribute_requirements_body(self):
         """The required attributes must be met from the SAML response."""
 
         # stub out the auth handler