summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/saml_handler.py103
1 files changed, 95 insertions, 8 deletions
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index a1ce6929cf..5fa8272dc9 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -21,6 +21,8 @@ from saml2.client import Saml2Client
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import parse_string
 from synapse.rest.client.v1.login import SSOAuthHandler
+from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.util.async_helpers import Linearizer
 
 logger = logging.getLogger(__name__)
 
@@ -29,12 +31,26 @@ class SamlHandler:
     def __init__(self, hs):
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
         self._sso_auth_handler = SSOAuthHandler(hs)
+        self._registration_handler = hs.get_registration_handler()
+
+        self._clock = hs.get_clock()
+        self._datastore = hs.get_datastore()
+        self._hostname = hs.hostname
+        self._saml2_session_lifetime = hs.config.saml2_session_lifetime
+        self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
+        self._grandfathered_mxid_source_attribute = (
+            hs.config.saml2_grandfathered_mxid_source_attribute
+        )
+        self._mxid_mapper = hs.config.saml2_mxid_mapper
+
+        # identifier for the external_ids table
+        self._auth_provider_id = "saml"
 
         # a map from saml session id to Saml2SessionData object
         self._outstanding_requests_dict = {}
 
-        self._clock = hs.get_clock()
-        self._saml2_session_lifetime = hs.config.saml2_session_lifetime
+        # a lock on the mappings
+        self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
 
     def handle_redirect_request(self, client_redirect_url):
         """Handle an incoming request to /login/sso/redirect
@@ -60,7 +76,7 @@ class SamlHandler:
         # this shouldn't happen!
         raise Exception("prepare_for_authenticate didn't return a Location header")
 
-    def handle_saml_response(self, request):
+    async def handle_saml_response(self, request):
         """Handle an incoming request to /_matrix/saml2/authn_response
 
         Args:
@@ -77,6 +93,10 @@ class SamlHandler:
         # the dict.
         self.expire_sessions()
 
+        user_id = await self._map_saml_response_to_user(resp_bytes)
+        self._sso_auth_handler.complete_sso_login(user_id, request, relay_state)
+
+    async def _map_saml_response_to_user(self, resp_bytes):
         try:
             saml2_auth = self._saml_client.parse_authn_request_response(
                 resp_bytes,
@@ -91,18 +111,85 @@ class SamlHandler:
             logger.warning("SAML2 response was not signed")
             raise SynapseError(400, "SAML2 response was not signed")
 
-        if "uid" not in saml2_auth.ava:
+        try:
+            remote_user_id = saml2_auth.ava["uid"][0]
+        except KeyError:
             logger.warning("SAML2 response lacks a 'uid' attestation")
             raise SynapseError(400, "uid not in SAML2 response")
 
+        try:
+            mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
+        except KeyError:
+            logger.warning(
+                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
+            )
+            raise SynapseError(
+                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+            )
+
         self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
 
-        username = saml2_auth.ava["uid"][0]
         displayName = saml2_auth.ava.get("displayName", [None])[0]
 
-        return self._sso_auth_handler.on_successful_auth(
-            username, request, relay_state, user_display_name=displayName
-        )
+        with (await self._mapping_lock.queue(self._auth_provider_id)):
+            # first of all, check if we already have a mapping for this user
+            logger.info(
+                "Looking for existing mapping for user %s:%s",
+                self._auth_provider_id,
+                remote_user_id,
+            )
+            registered_user_id = await self._datastore.get_user_by_external_id(
+                self._auth_provider_id, remote_user_id
+            )
+            if registered_user_id is not None:
+                logger.info("Found existing mapping %s", registered_user_id)
+                return registered_user_id
+
+            # backwards-compatibility hack: see if there is an existing user with a
+            # suitable mapping from the uid
+            if (
+                self._grandfathered_mxid_source_attribute
+                and self._grandfathered_mxid_source_attribute in saml2_auth.ava
+            ):
+                attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
+                user_id = UserID(
+                    map_username_to_mxid_localpart(attrval), self._hostname
+                ).to_string()
+                logger.info(
+                    "Looking for existing account based on mapped %s %s",
+                    self._grandfathered_mxid_source_attribute,
+                    user_id,
+                )
+
+                users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+                if users:
+                    registered_user_id = list(users.keys())[0]
+                    logger.info("Grandfathering mapping to %s", registered_user_id)
+                    await self._datastore.record_user_external_id(
+                        self._auth_provider_id, remote_user_id, registered_user_id
+                    )
+                    return registered_user_id
+
+            # figure out a new mxid for this user
+            base_mxid_localpart = self._mxid_mapper(mxid_source)
+
+            suffix = 0
+            while True:
+                localpart = base_mxid_localpart + (str(suffix) if suffix else "")
+                if not await self._datastore.get_users_by_id_case_insensitive(
+                    UserID(localpart, self._hostname).to_string()
+                ):
+                    break
+                suffix += 1
+            logger.info("Allocating mxid for new user with localpart %s", localpart)
+
+            registered_user_id = await self._registration_handler.register_user(
+                localpart=localpart, default_display_name=displayName
+            )
+            await self._datastore.record_user_external_id(
+                self._auth_provider_id, remote_user_id, registered_user_id
+            )
+            return registered_user_id
 
     def expire_sessions(self):
         expire_before = self._clock.time_msec() - self._saml2_session_lifetime