summary refs log tree commit diff
path: root/synapse/handlers/oidc_handler.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/oidc_handler.py')
-rw-r--r--synapse/handlers/oidc_handler.py59
1 files changed, 27 insertions, 32 deletions
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index cbd11a1382..709f8dfc13 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -947,7 +947,7 @@ class OidcHandler(BaseHandler):
 
 
 UserAttributeDict = TypedDict(
-    "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
+    "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
 )
 C = TypeVar("C")
 
@@ -1028,10 +1028,10 @@ env = Environment(finalize=jinja_finalize)
 
 @attr.s
 class JinjaOidcMappingConfig:
-    subject_claim = attr.ib()  # type: str
-    localpart_template = attr.ib()  # type: Template
-    display_name_template = attr.ib()  # type: Optional[Template]
-    extra_attributes = attr.ib()  # type: Dict[str, Template]
+    subject_claim = attr.ib(type=str)
+    localpart_template = attr.ib(type=Optional[Template])
+    display_name_template = attr.ib(type=Optional[Template])
+    extra_attributes = attr.ib(type=Dict[str, Template])
 
 
 class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@@ -1047,18 +1047,14 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
     def parse_config(config: dict) -> JinjaOidcMappingConfig:
         subject_claim = config.get("subject_claim", "sub")
 
-        if "localpart_template" not in config:
-            raise ConfigError(
-                "missing key: oidc_config.user_mapping_provider.config.localpart_template"
-            )
-
-        try:
-            localpart_template = env.from_string(config["localpart_template"])
-        except Exception as e:
-            raise ConfigError(
-                "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r"
-                % (e,)
-            )
+        localpart_template = None  # type: Optional[Template]
+        if "localpart_template" in config:
+            try:
+                localpart_template = env.from_string(config["localpart_template"])
+            except Exception as e:
+                raise ConfigError(
+                    "invalid jinja template", path=["localpart_template"]
+                ) from e
 
         display_name_template = None  # type: Optional[Template]
         if "display_name_template" in config:
@@ -1066,26 +1062,22 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
                 display_name_template = env.from_string(config["display_name_template"])
             except Exception as e:
                 raise ConfigError(
-                    "invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r"
-                    % (e,)
-                )
+                    "invalid jinja template", path=["display_name_template"]
+                ) from e
 
         extra_attributes = {}  # type Dict[str, Template]
         if "extra_attributes" in config:
             extra_attributes_config = config.get("extra_attributes") or {}
             if not isinstance(extra_attributes_config, dict):
-                raise ConfigError(
-                    "oidc_config.user_mapping_provider.config.extra_attributes must be a dict"
-                )
+                raise ConfigError("must be a dict", path=["extra_attributes"])
 
             for key, value in extra_attributes_config.items():
                 try:
                     extra_attributes[key] = env.from_string(value)
                 except Exception as e:
                     raise ConfigError(
-                        "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r"
-                        % (key, e)
-                    )
+                        "invalid jinja template", path=["extra_attributes", key]
+                    ) from e
 
         return JinjaOidcMappingConfig(
             subject_claim=subject_claim,
@@ -1100,14 +1092,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
     async def map_user_attributes(
         self, userinfo: UserInfo, token: Token, failures: int
     ) -> UserAttributeDict:
-        localpart = self._config.localpart_template.render(user=userinfo).strip()
+        localpart = None
+
+        if self._config.localpart_template:
+            localpart = self._config.localpart_template.render(user=userinfo).strip()
 
-        # Ensure only valid characters are included in the MXID.
-        localpart = map_username_to_mxid_localpart(localpart)
+            # Ensure only valid characters are included in the MXID.
+            localpart = map_username_to_mxid_localpart(localpart)
 
-        # Append suffix integer if last call to this function failed to produce
-        # a usable mxid.
-        localpart += str(failures) if failures else ""
+            # Append suffix integer if last call to this function failed to produce
+            # a usable mxid.
+            localpart += str(failures) if failures else ""
 
         display_name = None  # type: Optional[str]
         if self._config.display_name_template is not None: