summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8938.feature1
-rw-r--r--synapse/handlers/saml_handler.py23
-rw-r--r--tests/handlers/test_oidc.py12
-rw-r--r--tests/handlers/test_saml.py132
-rw-r--r--tests/test_utils/__init__.py12
5 files changed, 126 insertions, 54 deletions
diff --git a/changelog.d/8938.feature b/changelog.d/8938.feature
new file mode 100644
index 0000000000..d450ef4998
--- /dev/null
+++ b/changelog.d/8938.feature
@@ -0,0 +1 @@
+Add support for allowing users to pick their own user ID during a single-sign-on login.
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index f2ca1ddb53..6001fe3e27 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -163,6 +163,29 @@ class SamlHandler(BaseHandler):
             return
 
         logger.debug("SAML2 response: %s", saml2_auth.origxml)
+
+        await self._handle_authn_response(request, saml2_auth, relay_state)
+
+    async def _handle_authn_response(
+        self,
+        request: SynapseRequest,
+        saml2_auth: saml2.response.AuthnResponse,
+        relay_state: str,
+    ) -> None:
+        """Handle an AuthnResponse, having parsed it from the request params
+
+        Assumes that the signature on the response object has been checked. Maps
+        the user onto an MXID, registering them if necessary, and returns a response
+        to the browser.
+
+        Args:
+            request: the incoming request from the browser. We'll respond to it with an
+                HTML page or a redirect
+            saml2_auth: the parsed AuthnResponse object
+            relay_state: the RelayState query param, which encodes the URI to rediret
+               back to
+        """
+
         for assertion in saml2_auth.assertions:
             # kibana limits the length of a log field, whereas this is all rather
             # useful, so split it up.
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 9878527bab..464e569ac8 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -23,7 +23,7 @@ from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
 from synapse.handlers.sso import MappingException
 from synapse.types import UserID
 
-from tests.test_utils import FakeResponse
+from tests.test_utils import FakeResponse, simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
 # These are a few constants that are used as config parameters in the tests.
@@ -82,16 +82,6 @@ class TestMappingProviderFailures(TestMappingProvider):
         }
 
 
-def simple_async_mock(return_value=None, raises=None) -> Mock:
-    # AsyncMock is not available in python3.5, this mimics part of its behaviour
-    async def cb(*args, **kwargs):
-        if raises:
-            raise raises
-        return return_value
-
-    return Mock(side_effect=cb)
-
-
 async def get_json(url):
     # Mock get_json calls to handle jwks & oidc discovery endpoints
     if url == WELL_KNOWN:
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index d21e5588ca..69927cf6be 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -12,11 +12,15 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 
+from typing import Optional
+
+from mock import Mock
+
 import attr
 
 from synapse.api.errors import RedirectException
-from synapse.handlers.sso import MappingException
 
+from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
 # Check if we have the dependencies to run the tests.
@@ -44,6 +48,8 @@ BASE_URL = "https://synapse/"
 @attr.s
 class FakeAuthnResponse:
     ava = attr.ib(type=dict)
+    assertions = attr.ib(type=list, factory=list)
+    in_response_to = attr.ib(type=Optional[str], default=None)
 
 
 class TestMappingProvider:
@@ -111,15 +117,22 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
     def test_map_saml_response_to_user(self):
         """Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # send a mocked-up SAML response to the callback
         saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
-        # The redirect_url doesn't matter with the default user mapping provider.
-        redirect_url = ""
-        mxid = self.get_success(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            )
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri"
         )
-        self.assertEqual(mxid, "@test_user:test")
 
     @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
     def test_map_saml_response_to_existing_user(self):
@@ -129,53 +142,81 @@ class SamlHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id="@test_user:test", password_hash=None)
         )
 
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
         # Map a user via SSO.
         saml_response = FakeAuthnResponse(
             {"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
         )
-        redirect_url = ""
-        mxid = self.get_success(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            )
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, ""
         )
-        self.assertEqual(mxid, "@test_user:test")
 
         # Subsequent calls should map to the same mxid.
-        mxid = self.get_success(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            )
+        auth_handler.complete_sso_login.reset_mock()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "")
+        )
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, ""
         )
-        self.assertEqual(mxid, "@test_user:test")
 
     def test_map_saml_response_to_invalid_localpart(self):
         """If the mapping provider generates an invalid localpart it should be rejected."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # mock out the error renderer too
+        sso_handler = self.hs.get_sso_handler()
+        sso_handler.render_error = Mock(return_value=None)
+
         saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
-        redirect_url = ""
-        e = self.get_failure(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, ""),
+        )
+        sso_handler.render_error.assert_called_once_with(
+            request, "mapping_error", "localpart is invalid: föö"
         )
-        self.assertEqual(str(e.value), "localpart is invalid: föö")
+        auth_handler.complete_sso_login.assert_not_called()
 
     def test_map_saml_response_to_user_retries(self):
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
+
+        # stub out the auth handler and error renderer
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+        sso_handler = self.hs.get_sso_handler()
+        sso_handler.render_error = Mock(return_value=None)
+
+        # register a user to occupy the first-choice MXID
         store = self.hs.get_datastore()
         self.get_success(
             store.register_user(user_id="@test_user:test", password_hash=None)
         )
+
+        # send the fake SAML response
         saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
-        redirect_url = ""
-        mxid = self.get_success(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            )
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, ""),
         )
+
         # test_user is already taken, so test_user1 gets registered instead.
-        self.assertEqual(mxid, "@test_user1:test")
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user1:test", request, ""
+        )
+        auth_handler.complete_sso_login.reset_mock()
 
         # Register all of the potential mxids for a particular SAML username.
         self.get_success(
@@ -188,15 +229,15 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # Now attempt to map to a username, this will fail since all potential usernames are taken.
         saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
-        e = self.get_failure(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, ""),
         )
-        self.assertEqual(
-            str(e.value), "Unable to generate a Matrix ID from the SSO response"
+        sso_handler.render_error.assert_called_once_with(
+            request,
+            "mapping_error",
+            "Unable to generate a Matrix ID from the SSO response",
         )
+        auth_handler.complete_sso_login.assert_not_called()
 
     @override_config(
         {
@@ -208,12 +249,17 @@ class SamlHandlerTestCase(HomeserverTestCase):
         }
     )
     def test_map_saml_response_redirect(self):
+        """Test a mapping provider that raises a RedirectException"""
+
         saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
-        redirect_url = ""
+        request = _mock_request()
         e = self.get_failure(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            ),
+            self.handler._handle_authn_response(request, saml_response, ""),
             RedirectException,
         )
         self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
+
+
+def _mock_request():
+    """Returns a mock which will stand in as a SynapseRequest"""
+    return Mock(spec=["getClientIP", "get_user_agent"])
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 6873d45eb6..43898d8142 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -22,6 +22,8 @@ import warnings
 from asyncio import Future
 from typing import Any, Awaitable, Callable, TypeVar
 
+from mock import Mock
+
 import attr
 
 from twisted.python.failure import Failure
@@ -87,6 +89,16 @@ def setup_awaitable_errors() -> Callable[[], None]:
     return cleanup
 
 
+def simple_async_mock(return_value=None, raises=None) -> Mock:
+    # AsyncMock is not available in python3.5, this mimics part of its behaviour
+    async def cb(*args, **kwargs):
+        if raises:
+            raise raises
+        return return_value
+
+    return Mock(side_effect=cb)
+
+
 @attr.s
 class FakeResponse:
     """A fake twisted.web.IResponse object