summary refs log tree commit diff
path: root/tests/handlers/test_saml.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_saml.py')
-rw-r--r--tests/handlers/test_saml.py34
1 files changed, 33 insertions, 1 deletions
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 79fd47036f..e1e13a5faf 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -16,7 +16,7 @@ import attr
 
 from synapse.handlers.sso import MappingException
 
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
 
 # These are a few constants that are used as config parameters in the tests.
 BASE_URL = "https://synapse/"
@@ -59,6 +59,10 @@ class SamlHandlerTestCase(HomeserverTestCase):
             "grandfathered_mxid_source_attribute": None,
             "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
         }
+
+        # Update this config with what's in the default config so that
+        # override_config works as expected.
+        saml_config.update(config.get("saml2_config", {}))
         config["saml2_config"] = saml_config
 
         return config
@@ -86,6 +90,34 @@ class SamlHandlerTestCase(HomeserverTestCase):
         )
         self.assertEqual(mxid, "@test_user:test")
 
+    @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
+    def test_map_saml_response_to_existing_user(self):
+        """Existing users can log in with SAML account."""
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.register_user(user_id="@test_user:test", password_hash=None)
+        )
+
+        # Map a user via SSO.
+        saml_response = FakeAuthnResponse(
+            {"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
+        )
+        redirect_url = ""
+        mxid = self.get_success(
+            self.handler._map_saml_response_to_user(
+                saml_response, redirect_url, "user-agent", "10.10.10.10"
+            )
+        )
+        self.assertEqual(mxid, "@test_user:test")
+
+        # Subsequent calls should map to the same mxid.
+        mxid = self.get_success(
+            self.handler._map_saml_response_to_user(
+                saml_response, redirect_url, "user-agent", "10.10.10.10"
+            )
+        )
+        self.assertEqual(mxid, "@test_user:test")
+
     def test_map_saml_response_to_invalid_localpart(self):
         """If the mapping provider generates an invalid localpart it should be rejected."""
         saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})