summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/oidc_handler.py25
-rw-r--r--synapse/handlers/saml_handler.py6
-rw-r--r--synapse/types.py6
3 files changed, 29 insertions, 8 deletions
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index be8562d47b..4bfd8d5617 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -38,7 +38,12 @@ from synapse.handlers._base import BaseHandler
 from synapse.handlers.sso import MappingException
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
-from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
+from synapse.types import (
+    JsonDict,
+    UserID,
+    contains_invalid_mxid_characters,
+    map_username_to_mxid_localpart,
+)
 from synapse.util import json_decoder
 
 if TYPE_CHECKING:
@@ -885,10 +890,12 @@ class OidcHandler(BaseHandler):
             "Retrieved user attributes from user mapping provider: %r", attributes
         )
 
-        if not attributes["localpart"]:
-            raise MappingException("localpart is empty")
-
-        localpart = map_username_to_mxid_localpart(attributes["localpart"])
+        localpart = attributes["localpart"]
+        if not localpart:
+            raise MappingException(
+                "Error parsing OIDC response: OIDC mapping provider plugin "
+                "did not return a localpart value"
+            )
 
         user_id = UserID(localpart, self.server_name).to_string()
         users = await self.store.get_users_by_id_case_insensitive(user_id)
@@ -908,6 +915,11 @@ class OidcHandler(BaseHandler):
                 # This mxid is taken
                 raise MappingException("mxid '{}' is already taken".format(user_id))
         else:
+            # Since the localpart is provided via a potentially untrusted module,
+            # ensure the MXID is valid before registering.
+            if contains_invalid_mxid_characters(localpart):
+                raise MappingException("localpart is invalid: %s" % (localpart,))
+
             # It's the first time this user is logging in and the mapped mxid was
             # not taken, register the user
             registered_user_id = await self._registration_handler.register_user(
@@ -1076,6 +1088,9 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
     ) -> UserAttribute:
         localpart = self._config.localpart_template.render(user=userinfo).strip()
 
+        # Ensure only valid characters are included in the MXID.
+        localpart = map_username_to_mxid_localpart(localpart)
+
         display_name = None  # type: Optional[str]
         if self._config.display_name_template is not None:
             display_name = self._config.display_name_template.render(
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 9bf430b656..5d9b555b13 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -31,6 +31,7 @@ from synapse.http.site import SynapseRequest
 from synapse.module_api import ModuleApi
 from synapse.types import (
     UserID,
+    contains_invalid_mxid_characters,
     map_username_to_mxid_localpart,
     mxid_localpart_allowed_characters,
 )
@@ -318,6 +319,11 @@ class SamlHandler(BaseHandler):
                     "Unable to generate a Matrix ID from the SAML response"
                 )
 
+            # Since the localpart is provided via a potentially untrusted module,
+            # ensure the MXID is valid before registering.
+            if contains_invalid_mxid_characters(localpart):
+                raise MappingException("localpart is invalid: %s" % (localpart,))
+
             logger.info("Mapped SAML user to local part %s", localpart)
             registered_user_id = await self._registration_handler.register_user(
                 localpart=localpart,
diff --git a/synapse/types.py b/synapse/types.py
index 66bb5bac8d..3ab6bdbe06 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -317,14 +317,14 @@ mxid_localpart_allowed_characters = set(
 )
 
 
-def contains_invalid_mxid_characters(localpart):
+def contains_invalid_mxid_characters(localpart: str) -> bool:
     """Check for characters not allowed in an mxid or groupid localpart
 
     Args:
-        localpart (basestring): the localpart to be checked
+        localpart: the localpart to be checked
 
     Returns:
-        bool: True if there are any naughty characters
+        True if there are any naughty characters
     """
     return any(c not in mxid_localpart_allowed_characters for c in localpart)