summary refs log tree commit diff
path: root/synapse/handlers/saml_handler.py
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2019-09-13 15:20:49 +0100
committerRichard van der Hoff <richard@matrix.org>2019-09-13 16:01:46 +0100
commita8ac40445c98b9e1fc2538d7d4ec49c80b0298ac (patch)
treee506896dcaa02f826ffe0e5e1c859acabb290626 /synapse/handlers/saml_handler.py
parentMake the sample saml config closer to our standards (diff)
downloadsynapse-a8ac40445c98b9e1fc2538d7d4ec49c80b0298ac.tar.xz
Record mappings from saml users in an external table
We want to assign unique mxids to saml users based on an incrementing
suffix. For that to work, we need to record the allocated mxid in a separate
table.
Diffstat (limited to 'synapse/handlers/saml_handler.py')
-rw-r--r--synapse/handlers/saml_handler.py103
1 files changed, 95 insertions, 8 deletions
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index a1ce6929cf..5fa8272dc9 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -21,6 +21,8 @@ from saml2.client import Saml2Client
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import parse_string
 from synapse.rest.client.v1.login import SSOAuthHandler
+from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.util.async_helpers import Linearizer
 
 logger = logging.getLogger(__name__)
 
@@ -29,12 +31,26 @@ class SamlHandler:
     def __init__(self, hs):
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
         self._sso_auth_handler = SSOAuthHandler(hs)
+        self._registration_handler = hs.get_registration_handler()
+
+        self._clock = hs.get_clock()
+        self._datastore = hs.get_datastore()
+        self._hostname = hs.hostname
+        self._saml2_session_lifetime = hs.config.saml2_session_lifetime
+        self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
+        self._grandfathered_mxid_source_attribute = (
+            hs.config.saml2_grandfathered_mxid_source_attribute
+        )
+        self._mxid_mapper = hs.config.saml2_mxid_mapper
+
+        # identifier for the external_ids table
+        self._auth_provider_id = "saml"
 
         # a map from saml session id to Saml2SessionData object
         self._outstanding_requests_dict = {}
 
-        self._clock = hs.get_clock()
-        self._saml2_session_lifetime = hs.config.saml2_session_lifetime
+        # a lock on the mappings
+        self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
 
     def handle_redirect_request(self, client_redirect_url):
         """Handle an incoming request to /login/sso/redirect
@@ -60,7 +76,7 @@ class SamlHandler:
         # this shouldn't happen!
         raise Exception("prepare_for_authenticate didn't return a Location header")
 
-    def handle_saml_response(self, request):
+    async def handle_saml_response(self, request):
         """Handle an incoming request to /_matrix/saml2/authn_response
 
         Args:
@@ -77,6 +93,10 @@ class SamlHandler:
         # the dict.
         self.expire_sessions()
 
+        user_id = await self._map_saml_response_to_user(resp_bytes)
+        self._sso_auth_handler.complete_sso_login(user_id, request, relay_state)
+
+    async def _map_saml_response_to_user(self, resp_bytes):
         try:
             saml2_auth = self._saml_client.parse_authn_request_response(
                 resp_bytes,
@@ -91,18 +111,85 @@ class SamlHandler:
             logger.warning("SAML2 response was not signed")
             raise SynapseError(400, "SAML2 response was not signed")
 
-        if "uid" not in saml2_auth.ava:
+        try:
+            remote_user_id = saml2_auth.ava["uid"][0]
+        except KeyError:
             logger.warning("SAML2 response lacks a 'uid' attestation")
             raise SynapseError(400, "uid not in SAML2 response")
 
+        try:
+            mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
+        except KeyError:
+            logger.warning(
+                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
+            )
+            raise SynapseError(
+                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+            )
+
         self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
 
-        username = saml2_auth.ava["uid"][0]
         displayName = saml2_auth.ava.get("displayName", [None])[0]
 
-        return self._sso_auth_handler.on_successful_auth(
-            username, request, relay_state, user_display_name=displayName
-        )
+        with (await self._mapping_lock.queue(self._auth_provider_id)):
+            # first of all, check if we already have a mapping for this user
+            logger.info(
+                "Looking for existing mapping for user %s:%s",
+                self._auth_provider_id,
+                remote_user_id,
+            )
+            registered_user_id = await self._datastore.get_user_by_external_id(
+                self._auth_provider_id, remote_user_id
+            )
+            if registered_user_id is not None:
+                logger.info("Found existing mapping %s", registered_user_id)
+                return registered_user_id
+
+            # backwards-compatibility hack: see if there is an existing user with a
+            # suitable mapping from the uid
+            if (
+                self._grandfathered_mxid_source_attribute
+                and self._grandfathered_mxid_source_attribute in saml2_auth.ava
+            ):
+                attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
+                user_id = UserID(
+                    map_username_to_mxid_localpart(attrval), self._hostname
+                ).to_string()
+                logger.info(
+                    "Looking for existing account based on mapped %s %s",
+                    self._grandfathered_mxid_source_attribute,
+                    user_id,
+                )
+
+                users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+                if users:
+                    registered_user_id = list(users.keys())[0]
+                    logger.info("Grandfathering mapping to %s", registered_user_id)
+                    await self._datastore.record_user_external_id(
+                        self._auth_provider_id, remote_user_id, registered_user_id
+                    )
+                    return registered_user_id
+
+            # figure out a new mxid for this user
+            base_mxid_localpart = self._mxid_mapper(mxid_source)
+
+            suffix = 0
+            while True:
+                localpart = base_mxid_localpart + (str(suffix) if suffix else "")
+                if not await self._datastore.get_users_by_id_case_insensitive(
+                    UserID(localpart, self._hostname).to_string()
+                ):
+                    break
+                suffix += 1
+            logger.info("Allocating mxid for new user with localpart %s", localpart)
+
+            registered_user_id = await self._registration_handler.register_user(
+                localpart=localpart, default_display_name=displayName
+            )
+            await self._datastore.record_user_external_id(
+                self._auth_provider_id, remote_user_id, registered_user_id
+            )
+            return registered_user_id
 
     def expire_sessions(self):
         expire_before = self._clock.time_msec() - self._saml2_session_lifetime