summary refs log tree commit diff
path: root/synapse/handlers/oidc_handler.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/oidc_handler.py')
-rw-r--r--synapse/handlers/oidc_handler.py102
1 files changed, 82 insertions, 20 deletions
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 1b06f3173f..19cd652675 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -37,7 +37,7 @@ from synapse.config import ConfigError
 from synapse.http.server import respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
 from synapse.util import json_decoder
 
 if TYPE_CHECKING:
@@ -114,6 +114,7 @@ class OidcHandler:
             hs.config.oidc_user_mapping_provider_config
         )  # type: OidcMappingProvider
         self._skip_verification = hs.config.oidc_skip_verification  # type: bool
+        self._allow_existing_users = hs.config.oidc_allow_existing_users  # type: bool
 
         self._http_client = hs.get_proxied_http_client()
         self._auth_handler = hs.get_auth_handler()
@@ -131,10 +132,10 @@ class OidcHandler:
     def _render_error(
         self, request, error: str, error_description: Optional[str] = None
     ) -> None:
-        """Renders the error template and respond with it.
+        """Render the error template and respond to the request with it.
 
         This is used to show errors to the user. The template of this page can
-        be found under ``synapse/res/templates/sso_error.html``.
+        be found under `synapse/res/templates/sso_error.html`.
 
         Args:
             request: The incoming request from the browser.
@@ -706,6 +707,15 @@ class OidcHandler:
             self._render_error(request, "mapping_error", str(e))
             return
 
+        # Mapping providers might not have get_extra_attributes: only call this
+        # method if it exists.
+        extra_attributes = None
+        get_extra_attributes = getattr(
+            self._user_mapping_provider, "get_extra_attributes", None
+        )
+        if get_extra_attributes:
+            extra_attributes = await get_extra_attributes(userinfo, token)
+
         # and finally complete the login
         if ui_auth_session_id:
             await self._auth_handler.complete_sso_ui_auth(
@@ -713,7 +723,7 @@ class OidcHandler:
             )
         else:
             await self._auth_handler.complete_sso_login(
-                user_id, request, client_redirect_url
+                user_id, request, client_redirect_url, extra_attributes
             )
 
     def _generate_oidc_session_token(
@@ -849,7 +859,8 @@ class OidcHandler:
         If we don't find the user that way, we should register the user,
         mapping the localpart and the display name from the UserInfo.
 
-        If a user already exists with the mxid we've mapped, raise an exception.
+        If a user already exists with the mxid we've mapped and allow_existing_users
+        is disabled, raise an exception.
 
         Args:
             userinfo: an object representing the user
@@ -905,21 +916,31 @@ class OidcHandler:
 
         localpart = map_username_to_mxid_localpart(attributes["localpart"])
 
-        user_id = UserID(localpart, self._hostname)
-        if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()):
-            # This mxid is taken
-            raise MappingException(
-                "mxid '{}' is already taken".format(user_id.to_string())
+        user_id = UserID(localpart, self._hostname).to_string()
+        users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+        if users:
+            if self._allow_existing_users:
+                if len(users) == 1:
+                    registered_user_id = next(iter(users))
+                elif user_id in users:
+                    registered_user_id = user_id
+                else:
+                    raise MappingException(
+                        "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+                            user_id, list(users.keys())
+                        )
+                    )
+            else:
+                # This mxid is taken
+                raise MappingException("mxid '{}' is already taken".format(user_id))
+        else:
+            # It's the first time this user is logging in and the mapped mxid was
+            # not taken, register the user
+            registered_user_id = await self._registration_handler.register_user(
+                localpart=localpart,
+                default_display_name=attributes["display_name"],
+                user_agent_ips=(user_agent, ip_address),
             )
-
-        # It's the first time this user is logging in and the mapped mxid was
-        # not taken, register the user
-        registered_user_id = await self._registration_handler.register_user(
-            localpart=localpart,
-            default_display_name=attributes["display_name"],
-            user_agent_ips=(user_agent, ip_address),
-        )
-
         await self._datastore.record_user_external_id(
             self._auth_provider_id, remote_user_id, registered_user_id,
         )
@@ -972,7 +993,7 @@ class OidcMappingProvider(Generic[C]):
     async def map_user_attributes(
         self, userinfo: UserInfo, token: Token
     ) -> UserAttribute:
-        """Map a ``UserInfo`` objects into user attributes.
+        """Map a `UserInfo` object into user attributes.
 
         Args:
             userinfo: An object representing the user given by the OIDC provider
@@ -983,6 +1004,18 @@ class OidcMappingProvider(Generic[C]):
         """
         raise NotImplementedError()
 
+    async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
+        """Map a `UserInfo` object into additional attributes passed to the client during login.
+
+        Args:
+            userinfo: An object representing the user given by the OIDC provider
+            token: A dict with the tokens returned by the provider
+
+        Returns:
+            A dict containing additional attributes. Must be JSON serializable.
+        """
+        return {}
+
 
 # Used to clear out "None" values in templates
 def jinja_finalize(thing):
@@ -997,6 +1030,7 @@ class JinjaOidcMappingConfig:
     subject_claim = attr.ib()  # type: str
     localpart_template = attr.ib()  # type: Template
     display_name_template = attr.ib()  # type: Optional[Template]
+    extra_attributes = attr.ib()  # type: Dict[str, Template]
 
 
 class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@@ -1035,10 +1069,28 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
                     % (e,)
                 )
 
+        extra_attributes = {}  # type Dict[str, Template]
+        if "extra_attributes" in config:
+            extra_attributes_config = config.get("extra_attributes") or {}
+            if not isinstance(extra_attributes_config, dict):
+                raise ConfigError(
+                    "oidc_config.user_mapping_provider.config.extra_attributes must be a dict"
+                )
+
+            for key, value in extra_attributes_config.items():
+                try:
+                    extra_attributes[key] = env.from_string(value)
+                except Exception as e:
+                    raise ConfigError(
+                        "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r"
+                        % (key, e)
+                    )
+
         return JinjaOidcMappingConfig(
             subject_claim=subject_claim,
             localpart_template=localpart_template,
             display_name_template=display_name_template,
+            extra_attributes=extra_attributes,
         )
 
     def get_remote_user_id(self, userinfo: UserInfo) -> str:
@@ -1059,3 +1111,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
                 display_name = None
 
         return UserAttribute(localpart=localpart, display_name=display_name)
+
+    async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
+        extras = {}  # type: Dict[str, str]
+        for key, template in self._config.extra_attributes.items():
+            try:
+                extras[key] = template.render(user=userinfo).strip()
+            except Exception as e:
+                # Log an error and skip this value (don't break login for this).
+                logger.error("Failed to render OIDC extra attribute %s: %s" % (key, e))
+        return extras