summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8856.misc1
-rw-r--r--synapse/handlers/cas_handler.py112
-rw-r--r--synapse/handlers/sso.py4
-rw-r--r--tests/handlers/test_cas.py121
4 files changed, 199 insertions, 39 deletions
diff --git a/changelog.d/8856.misc b/changelog.d/8856.misc
new file mode 100644
index 0000000000..1507073e4f
--- /dev/null
+++ b/changelog.d/8856.misc
@@ -0,0 +1 @@
+Properly store the mapping of external ID to Matrix ID for CAS users.
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index e9891e1316..fca210a5a6 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -22,6 +22,7 @@ import attr
 from twisted.web.client import PartialDownloadError
 
 from synapse.api.errors import HttpResponseException
+from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.site import SynapseRequest
 from synapse.types import UserID, map_username_to_mxid_localpart
 
@@ -62,6 +63,7 @@ class CasHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self._hostname = hs.hostname
+        self._store = hs.get_datastore()
         self._auth_handler = hs.get_auth_handler()
         self._registration_handler = hs.get_registration_handler()
 
@@ -72,6 +74,9 @@ class CasHandler:
 
         self._http_client = hs.get_proxied_http_client()
 
+        # identifier for the external_ids table
+        self._auth_provider_id = "cas"
+
         self._sso_handler = hs.get_sso_handler()
 
     def _build_service_param(self, args: Dict[str, str]) -> str:
@@ -267,6 +272,14 @@ class CasHandler:
                 This should be the UI Auth session id.
         """
 
+        # first check if we're doing a UIA
+        if session:
+            return await self._sso_handler.complete_sso_ui_auth_request(
+                self._auth_provider_id, cas_response.username, session, request,
+            )
+
+        # otherwise, we're handling a login request.
+
         # Ensure that the attributes of the logged in user meet the required
         # attributes.
         for required_attribute, required_value in self._cas_required_attributes.items():
@@ -293,54 +306,79 @@ class CasHandler:
                     )
                     return
 
-        # Pull out the user-agent and IP from the request.
-        user_agent = request.get_user_agent("")
-        ip_address = self.hs.get_ip_from_request(request)
-
-        # Get the matrix ID from the CAS username.
-        user_id = await self._map_cas_user_to_matrix_user(
-            cas_response, user_agent, ip_address
-        )
+        # Call the mapper to register/login the user
 
-        if session:
-            await self._auth_handler.complete_sso_ui_auth(
-                user_id, session, request,
-            )
-        else:
-            # If this not a UI auth request than there must be a redirect URL.
-            assert client_redirect_url
+        # If this not a UI auth request than there must be a redirect URL.
+        assert client_redirect_url is not None
 
-            await self._auth_handler.complete_sso_login(
-                user_id, request, client_redirect_url
-            )
+        try:
+            await self._complete_cas_login(cas_response, request, client_redirect_url)
+        except MappingException as e:
+            logger.exception("Could not map user")
+            self._sso_handler.render_error(request, "mapping_error", str(e))
 
-    async def _map_cas_user_to_matrix_user(
-        self, cas_response: CasResponse, user_agent: str, ip_address: str,
-    ) -> str:
+    async def _complete_cas_login(
+        self,
+        cas_response: CasResponse,
+        request: SynapseRequest,
+        client_redirect_url: str,
+    ) -> None:
         """
-        Given a CAS username, retrieve the user ID for it and possibly register the user.
+        Given a CAS response, complete the login flow
+
+        Retrieves the remote user ID, registers the user if necessary, and serves
+        a redirect back to the client with a login-token.
 
         Args:
             cas_response: The parsed CAS response.
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
+            request: The request to respond to
+            client_redirect_url: The redirect URL passed in by the client.
 
-        Returns:
-             The user ID associated with this response.
+        Raises:
+            MappingException if there was a problem mapping the response to a user.
+            RedirectException: some mapping providers may raise this if they need
+                to redirect to an interstitial page.
         """
-
+        # Note that CAS does not support a mapping provider, so the logic is hard-coded.
         localpart = map_username_to_mxid_localpart(cas_response.username)
-        user_id = UserID(localpart, self._hostname).to_string()
-        registered_user_id = await self._auth_handler.check_user_exists(user_id)
 
-        displayname = cas_response.attributes.get(self._cas_displayname_attribute, None)
+        async def cas_response_to_user_attributes(failures: int) -> UserAttributes:
+            """
+            Map from CAS attributes to user attributes.
+            """
+            # Due to the grandfathering logic matching any previously registered
+            # mxids it isn't expected for there to be any failures.
+            if failures:
+                raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
+
+            display_name = cas_response.attributes.get(
+                self._cas_displayname_attribute, None
+            )
+
+            return UserAttributes(localpart=localpart, display_name=display_name)
 
-        # If the user does not exist, register it.
-        if not registered_user_id:
-            registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart,
-                default_display_name=displayname,
-                user_agent_ips=[(user_agent, ip_address)],
+        async def grandfather_existing_users() -> Optional[str]:
+            # Since CAS did not always use the user_external_ids table, always
+            # to attempt to map to existing users.
+            user_id = UserID(localpart, self._hostname).to_string()
+
+            logger.debug(
+                "Looking for existing account based on mapped %s", user_id,
             )
 
-        return registered_user_id
+            users = await self._store.get_users_by_id_case_insensitive(user_id)
+            if users:
+                registered_user_id = list(users.keys())[0]
+                logger.info("Grandfathering mapping to %s", registered_user_id)
+                return registered_user_id
+
+            return None
+
+        await self._sso_handler.complete_sso_login_request(
+            self._auth_provider_id,
+            cas_response.username,
+            request,
+            client_redirect_url,
+            cas_response_to_user_attributes,
+            grandfather_existing_users,
+        )
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index b0a8c8c7d2..33cd6bc178 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -173,7 +173,7 @@ class SsoHandler:
         request: SynapseRequest,
         client_redirect_url: str,
         sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
-        grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
+        grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
         extra_login_attributes: Optional[JsonDict] = None,
     ) -> None:
         """
@@ -241,7 +241,7 @@ class SsoHandler:
             )
 
             # Check for grandfathering of users.
-            if not user_id and grandfather_existing_users:
+            if not user_id:
                 user_id = await grandfather_existing_users()
                 if user_id:
                     # Future logins should also match this user ID.
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
new file mode 100644
index 0000000000..bd7a1b6891
--- /dev/null
+++ b/tests/handlers/test_cas.py
@@ -0,0 +1,121 @@
+#  Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+from mock import Mock
+
+from synapse.handlers.cas_handler import CasResponse
+
+from tests.test_utils import simple_async_mock
+from tests.unittest import HomeserverTestCase
+
+# These are a few constants that are used as config parameters in the tests.
+BASE_URL = "https://synapse/"
+SERVER_URL = "https://issuer/"
+
+
+class CasHandlerTestCase(HomeserverTestCase):
+    def default_config(self):
+        config = super().default_config()
+        config["public_baseurl"] = BASE_URL
+        cas_config = {
+            "enabled": True,
+            "server_url": SERVER_URL,
+            "service_url": BASE_URL,
+        }
+        config["cas_config"] = cas_config
+
+        return config
+
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver()
+
+        self.handler = hs.get_cas_handler()
+
+        # Reduce the number of attempts when generating MXIDs.
+        sso_handler = hs.get_sso_handler()
+        sso_handler._MAP_USERNAME_RETRIES = 3
+
+        return hs
+
+    def test_map_cas_user_to_user(self):
+        """Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        cas_response = CasResponse("test_user", {})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None
+        )
+
+    def test_map_cas_user_to_existing_user(self):
+        """Existing users can log in with CAS account."""
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.register_user(user_id="@test_user:test", password_hash=None)
+        )
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # Map a user via SSO.
+        cas_response = CasResponse("test_user", {})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None
+        )
+
+        # Subsequent calls should map to the same mxid.
+        auth_handler.complete_sso_login.reset_mock()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None
+        )
+
+    def test_map_cas_user_to_invalid_localpart(self):
+        """CAS automaps invalid characters to base-64 encoding."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        cas_response = CasResponse("föö", {})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@f=c3=b6=c3=b6:test", request, "redirect_uri", None
+        )
+
+
+def _mock_request():
+    """Returns a mock which will stand in as a SynapseRequest"""
+    return Mock(spec=["getClientIP", "get_user_agent"])