diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 87f0c5e197..19cd652675 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -12,9 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import logging
-from typing import Dict, Generic, List, Optional, Tuple, TypeVar
+from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from urllib.parse import urlencode
import attr
@@ -38,9 +37,11 @@ from synapse.config import ConfigError
from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
-from synapse.push.mailer import load_jinja2_templates
-from synapse.server import HomeServer
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
+from synapse.util import json_decoder
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -91,7 +92,8 @@ class OidcHandler:
"""Handles requests related to the OpenID Connect login flow.
"""
- def __init__(self, hs: HomeServer):
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._client_auth = ClientAuth(
@@ -112,6 +114,7 @@ class OidcHandler:
hs.config.oidc_user_mapping_provider_config
) # type: OidcMappingProvider
self._skip_verification = hs.config.oidc_skip_verification # type: bool
+ self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool
self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler()
@@ -121,9 +124,7 @@ class OidcHandler:
self._hostname = hs.hostname # type: str
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
- self._error_template = load_jinja2_templates(
- hs.config.sso_template_dir, ["sso_error.html"]
- )[0]
+ self._error_template = hs.config.sso_error_template
# identifier for the external_ids table
self._auth_provider_id = "oidc"
@@ -131,10 +132,10 @@ class OidcHandler:
def _render_error(
self, request, error: str, error_description: Optional[str] = None
) -> None:
- """Renders the error template and respond with it.
+ """Render the error template and respond to the request with it.
This is used to show errors to the user. The template of this page can
- be found under ``synapse/res/templates/sso_error.html``.
+ be found under `synapse/res/templates/sso_error.html`.
Args:
request: The incoming request from the browser.
@@ -368,7 +369,7 @@ class OidcHandler:
# and check for an error field. If not, we respond with a generic
# error message.
try:
- resp = json.loads(resp_body.decode("utf-8"))
+ resp = json_decoder.decode(resp_body.decode("utf-8"))
error = resp["error"]
description = resp.get("error_description", error)
except (ValueError, KeyError):
@@ -385,7 +386,7 @@ class OidcHandler:
# Since it is a not a 5xx code, body should be a valid JSON. It will
# raise if not.
- resp = json.loads(resp_body.decode("utf-8"))
+ resp = json_decoder.decode(resp_body.decode("utf-8"))
if "error" in resp:
error = resp["error"]
@@ -690,14 +691,31 @@ class OidcHandler:
self._render_error(request, "invalid_token", str(e))
return
+ # Pull out the user-agent and IP from the request.
+ user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+ 0
+ ].decode("ascii", "surrogateescape")
+ ip_address = self.hs.get_ip_from_request(request)
+
# Call the mapper to register/login the user
try:
- user_id = await self._map_userinfo_to_user(userinfo, token)
+ user_id = await self._map_userinfo_to_user(
+ userinfo, token, user_agent, ip_address
+ )
except MappingException as e:
logger.exception("Could not map user")
self._render_error(request, "mapping_error", str(e))
return
+ # Mapping providers might not have get_extra_attributes: only call this
+ # method if it exists.
+ extra_attributes = None
+ get_extra_attributes = getattr(
+ self._user_mapping_provider, "get_extra_attributes", None
+ )
+ if get_extra_attributes:
+ extra_attributes = await get_extra_attributes(userinfo, token)
+
# and finally complete the login
if ui_auth_session_id:
await self._auth_handler.complete_sso_ui_auth(
@@ -705,7 +723,7 @@ class OidcHandler:
)
else:
await self._auth_handler.complete_sso_login(
- user_id, request, client_redirect_url
+ user_id, request, client_redirect_url, extra_attributes
)
def _generate_oidc_session_token(
@@ -829,7 +847,9 @@ class OidcHandler:
now = self._clock.time_msec()
return now < expiry
- async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
+ async def _map_userinfo_to_user(
+ self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
+ ) -> str:
"""Maps a UserInfo object to a mxid.
UserInfo should have a claim that uniquely identifies users. This claim
@@ -839,11 +859,14 @@ class OidcHandler:
If we don't find the user that way, we should register the user,
mapping the localpart and the display name from the UserInfo.
- If a user already exists with the mxid we've mapped, raise an exception.
+ If a user already exists with the mxid we've mapped and allow_existing_users
+ is disabled, raise an exception.
Args:
userinfo: an object representing the user
token: a dict with the tokens obtained from the provider
+ user_agent: The user agent of the client making the request.
+ ip_address: The IP address of the client making the request.
Raises:
MappingException: if there was an error while mapping some properties
@@ -857,6 +880,9 @@ class OidcHandler:
raise MappingException(
"Failed to extract subject from OIDC response: %s" % (e,)
)
+ # Some OIDC providers use integer IDs, but Synapse expects external IDs
+ # to be strings.
+ remote_user_id = str(remote_user_id)
logger.info(
"Looking for existing mapping for user %s:%s",
@@ -890,19 +916,31 @@ class OidcHandler:
localpart = map_username_to_mxid_localpart(attributes["localpart"])
- user_id = UserID(localpart, self._hostname)
- if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()):
- # This mxid is taken
- raise MappingException(
- "mxid '{}' is already taken".format(user_id.to_string())
+ user_id = UserID(localpart, self._hostname).to_string()
+ users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+ if users:
+ if self._allow_existing_users:
+ if len(users) == 1:
+ registered_user_id = next(iter(users))
+ elif user_id in users:
+ registered_user_id = user_id
+ else:
+ raise MappingException(
+ "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+ user_id, list(users.keys())
+ )
+ )
+ else:
+ # This mxid is taken
+ raise MappingException("mxid '{}' is already taken".format(user_id))
+ else:
+ # It's the first time this user is logging in and the mapped mxid was
+ # not taken, register the user
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=localpart,
+ default_display_name=attributes["display_name"],
+ user_agent_ips=(user_agent, ip_address),
)
-
- # It's the first time this user is logging in and the mapped mxid was
- # not taken, register the user
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart, default_display_name=attributes["display_name"],
- )
-
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id,
)
@@ -955,7 +993,7 @@ class OidcMappingProvider(Generic[C]):
async def map_user_attributes(
self, userinfo: UserInfo, token: Token
) -> UserAttribute:
- """Map a ``UserInfo`` objects into user attributes.
+ """Map a `UserInfo` object into user attributes.
Args:
userinfo: An object representing the user given by the OIDC provider
@@ -966,6 +1004,18 @@ class OidcMappingProvider(Generic[C]):
"""
raise NotImplementedError()
+ async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
+ """Map a `UserInfo` object into additional attributes passed to the client during login.
+
+ Args:
+ userinfo: An object representing the user given by the OIDC provider
+ token: A dict with the tokens returned by the provider
+
+ Returns:
+ A dict containing additional attributes. Must be JSON serializable.
+ """
+ return {}
+
# Used to clear out "None" values in templates
def jinja_finalize(thing):
@@ -980,6 +1030,7 @@ class JinjaOidcMappingConfig:
subject_claim = attr.ib() # type: str
localpart_template = attr.ib() # type: Template
display_name_template = attr.ib() # type: Optional[Template]
+ extra_attributes = attr.ib() # type: Dict[str, Template]
class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@@ -1018,10 +1069,28 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
% (e,)
)
+ extra_attributes = {} # type Dict[str, Template]
+ if "extra_attributes" in config:
+ extra_attributes_config = config.get("extra_attributes") or {}
+ if not isinstance(extra_attributes_config, dict):
+ raise ConfigError(
+ "oidc_config.user_mapping_provider.config.extra_attributes must be a dict"
+ )
+
+ for key, value in extra_attributes_config.items():
+ try:
+ extra_attributes[key] = env.from_string(value)
+ except Exception as e:
+ raise ConfigError(
+ "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r"
+ % (key, e)
+ )
+
return JinjaOidcMappingConfig(
subject_claim=subject_claim,
localpart_template=localpart_template,
display_name_template=display_name_template,
+ extra_attributes=extra_attributes,
)
def get_remote_user_id(self, userinfo: UserInfo) -> str:
@@ -1042,3 +1111,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
display_name = None
return UserAttribute(localpart=localpart, display_name=display_name)
+
+ async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
+ extras = {} # type: Dict[str, str]
+ for key, template in self._config.extra_attributes.items():
+ try:
+ extras[key] = template.render(user=userinfo).strip()
+ except Exception as e:
+ # Log an error and skip this value (don't break login for this).
+ logger.error("Failed to render OIDC extra attribute %s: %s" % (key, e))
+ return extras
|