diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 19cd652675..c605f7082a 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from urllib.parse import urlencode
@@ -34,7 +35,8 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody
from synapse.config import ConfigError
-from synapse.http.server import respond_with_html
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
@@ -83,19 +85,15 @@ class OidcError(Exception):
return self.error
-class MappingException(Exception):
- """Used to catch errors when mapping the UserInfo object
- """
-
-
-class OidcHandler:
+class OidcHandler(BaseHandler):
"""Handles requests related to the OpenID Connect login flow.
"""
def __init__(self, hs: "HomeServer"):
- self.hs = hs
+ super().__init__(hs)
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
+ self._user_profile_method = hs.config.oidc_user_profile_method # type: str
self._client_auth = ClientAuth(
hs.config.oidc_client_id,
hs.config.oidc_client_secret,
@@ -119,36 +117,13 @@ class OidcHandler:
self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
- self._datastore = hs.get_datastore()
- self._clock = hs.get_clock()
- self._hostname = hs.hostname # type: str
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
- self._error_template = hs.config.sso_error_template
# identifier for the external_ids table
self._auth_provider_id = "oidc"
- def _render_error(
- self, request, error: str, error_description: Optional[str] = None
- ) -> None:
- """Render the error template and respond to the request with it.
-
- This is used to show errors to the user. The template of this page can
- be found under `synapse/res/templates/sso_error.html`.
-
- Args:
- request: The incoming request from the browser.
- We'll respond with an HTML page describing the error.
- error: A technical identifier for this error. Those include
- well-known OAuth2/OIDC error types like invalid_request or
- access_denied.
- error_description: A human-readable description of the error.
- """
- html = self._error_template.render(
- error=error, error_description=error_description
- )
- respond_with_html(request, 400, html)
+ self._sso_handler = hs.get_sso_handler()
def _validate_metadata(self):
"""Verifies the provider metadata.
@@ -196,11 +171,11 @@ class OidcHandler:
% (m["response_types_supported"],)
)
- # If the openid scope was not requested, we need a userinfo endpoint to fetch user infos
+ # Ensure there's a userinfo endpoint to fetch from if it is required.
if self._uses_userinfo:
if m.get("userinfo_endpoint") is None:
raise ValueError(
- 'provider has no "userinfo_endpoint", even though it is required because the "openid" scope is not requested'
+ 'provider has no "userinfo_endpoint", even though it is required'
)
else:
# If we're not using userinfo, we need a valid jwks to validate the ID token
@@ -216,12 +191,14 @@ class OidcHandler:
This is based on the requested scopes: if the scopes include
``openid``, the provider should give use an ID token containing the
- user informations. If not, we should fetch them using the
+ user information. If not, we should fetch them using the
``access_token`` with the ``userinfo_endpoint``.
"""
- # Maybe that should be user-configurable and not inferred?
- return "openid" not in self._scopes
+ return (
+ "openid" not in self._scopes
+ or self._user_profile_method == "userinfo_endpoint"
+ )
async def load_metadata(self) -> OpenIDProviderMetadata:
"""Load and validate the provider metadata.
@@ -423,7 +400,7 @@ class OidcHandler:
return resp
async def _fetch_userinfo(self, token: Token) -> UserInfo:
- """Fetch user informations from the ``userinfo_endpoint``.
+ """Fetch user information from the ``userinfo_endpoint``.
Args:
token: the token given by the ``token_endpoint``.
@@ -568,7 +545,7 @@ class OidcHandler:
Since we might want to display OIDC-related errors in a user-friendly
way, we don't raise SynapseError from here. Instead, we call
- ``self._render_error`` which displays an HTML page for the error.
+ ``self._sso_handler.render_error`` which displays an HTML page for the error.
Most of the OpenID Connect logic happens here:
@@ -606,7 +583,7 @@ class OidcHandler:
if error != "access_denied":
logger.error("Error from the OIDC provider: %s %s", error, description)
- self._render_error(request, error, description)
+ self._sso_handler.render_error(request, error, description)
return
# otherwise, it is presumably a successful response. see:
@@ -616,7 +593,9 @@ class OidcHandler:
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
if session is None:
logger.info("No session cookie found")
- self._render_error(request, "missing_session", "No session cookie found")
+ self._sso_handler.render_error(
+ request, "missing_session", "No session cookie found"
+ )
return
# Remove the cookie. There is a good chance that if the callback failed
@@ -634,7 +613,9 @@ class OidcHandler:
# Check for the state query parameter
if b"state" not in request.args:
logger.info("State parameter is missing")
- self._render_error(request, "invalid_request", "State parameter is missing")
+ self._sso_handler.render_error(
+ request, "invalid_request", "State parameter is missing"
+ )
return
state = request.args[b"state"][0].decode()
@@ -648,17 +629,19 @@ class OidcHandler:
) = self._verify_oidc_session_token(session, state)
except MacaroonDeserializationException as e:
logger.exception("Invalid session")
- self._render_error(request, "invalid_session", str(e))
+ self._sso_handler.render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
logger.exception("Could not verify session")
- self._render_error(request, "mismatching_session", str(e))
+ self._sso_handler.render_error(request, "mismatching_session", str(e))
return
# Exchange the code with the provider
if b"code" not in request.args:
logger.info("Code parameter is missing")
- self._render_error(request, "invalid_request", "Code parameter is missing")
+ self._sso_handler.render_error(
+ request, "invalid_request", "Code parameter is missing"
+ )
return
logger.debug("Exchanging code")
@@ -667,7 +650,7 @@ class OidcHandler:
token = await self._exchange_code(code)
except OidcError as e:
logger.exception("Could not exchange code")
- self._render_error(request, e.error, e.error_description)
+ self._sso_handler.render_error(request, e.error, e.error_description)
return
logger.debug("Successfully obtained OAuth2 access token")
@@ -680,7 +663,7 @@ class OidcHandler:
userinfo = await self._fetch_userinfo(token)
except Exception as e:
logger.exception("Could not fetch userinfo")
- self._render_error(request, "fetch_error", str(e))
+ self._sso_handler.render_error(request, "fetch_error", str(e))
return
else:
logger.debug("Extracting userinfo from id_token")
@@ -688,13 +671,11 @@ class OidcHandler:
userinfo = await self._parse_id_token(token, nonce=nonce)
except Exception as e:
logger.exception("Invalid id_token")
- self._render_error(request, "invalid_token", str(e))
+ self._sso_handler.render_error(request, "invalid_token", str(e))
return
# Pull out the user-agent and IP from the request.
- user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
- 0
- ].decode("ascii", "surrogateescape")
+ user_agent = request.get_user_agent("")
ip_address = self.hs.get_ip_from_request(request)
# Call the mapper to register/login the user
@@ -704,7 +685,7 @@ class OidcHandler:
)
except MappingException as e:
logger.exception("Could not map user")
- self._render_error(request, "mapping_error", str(e))
+ self._sso_handler.render_error(request, "mapping_error", str(e))
return
# Mapping providers might not have get_extra_attributes: only call this
@@ -753,7 +734,7 @@ class OidcHandler:
Defaults to an hour.
Returns:
- A signed macaroon token with the session informations.
+ A signed macaroon token with the session information.
"""
macaroon = pymacaroons.Macaroon(
location=self._server_name, identifier="key", key=self._macaroon_secret_key,
@@ -769,7 +750,7 @@ class OidcHandler:
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (ui_auth_session_id,)
)
- now = self._clock.time_msec()
+ now = self.clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
@@ -844,7 +825,7 @@ class OidcHandler:
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
- now = self._clock.time_msec()
+ now = self.clock.time_msec()
return now < expiry
async def _map_userinfo_to_user(
@@ -884,71 +865,77 @@ class OidcHandler:
# to be strings.
remote_user_id = str(remote_user_id)
- logger.info(
- "Looking for existing mapping for user %s:%s",
- self._auth_provider_id,
- remote_user_id,
- )
-
- registered_user_id = await self._datastore.get_user_by_external_id(
- self._auth_provider_id, remote_user_id,
+ # Older mapping providers don't accept the `failures` argument, so we
+ # try and detect support.
+ mapper_signature = inspect.signature(
+ self._user_mapping_provider.map_user_attributes
)
+ supports_failures = "failures" in mapper_signature.parameters
- if registered_user_id is not None:
- logger.info("Found existing mapping %s", registered_user_id)
- return registered_user_id
+ async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
+ """
+ Call the mapping provider to map the OIDC userinfo and token to user attributes.
- try:
- attributes = await self._user_mapping_provider.map_user_attributes(
- userinfo, token
- )
- except Exception as e:
- raise MappingException(
- "Could not extract user attributes from OIDC response: " + str(e)
- )
-
- logger.debug(
- "Retrieved user attributes from user mapping provider: %r", attributes
- )
+ This is backwards compatibility for abstraction for the SSO handler.
+ """
+ if supports_failures:
+ attributes = await self._user_mapping_provider.map_user_attributes(
+ userinfo, token, failures
+ )
+ else:
+ # If the mapping provider does not support processing failures,
+ # do not continually generate the same Matrix ID since it will
+ # continue to already be in use. Note that the error raised is
+ # arbitrary and will get turned into a MappingException.
+ if failures:
+ raise MappingException(
+ "Mapping provider does not support de-duplicating Matrix IDs"
+ )
- if not attributes["localpart"]:
- raise MappingException("localpart is empty")
+ attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore
+ userinfo, token
+ )
- localpart = map_username_to_mxid_localpart(attributes["localpart"])
+ return UserAttributes(**attributes)
- user_id = UserID(localpart, self._hostname).to_string()
- users = await self._datastore.get_users_by_id_case_insensitive(user_id)
- if users:
+ async def grandfather_existing_users() -> Optional[str]:
if self._allow_existing_users:
- if len(users) == 1:
- registered_user_id = next(iter(users))
- elif user_id in users:
- registered_user_id = user_id
- else:
- raise MappingException(
- "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
- user_id, list(users.keys())
+ # If allowing existing users we want to generate a single localpart
+ # and attempt to match it.
+ attributes = await oidc_response_to_user_attributes(failures=0)
+
+ user_id = UserID(attributes.localpart, self.server_name).to_string()
+ users = await self.store.get_users_by_id_case_insensitive(user_id)
+ if users:
+ # If an existing matrix ID is returned, then use it.
+ if len(users) == 1:
+ previously_registered_user_id = next(iter(users))
+ elif user_id in users:
+ previously_registered_user_id = user_id
+ else:
+ # Do not attempt to continue generating Matrix IDs.
+ raise MappingException(
+ "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+ user_id, users
+ )
)
- )
- else:
- # This mxid is taken
- raise MappingException("mxid '{}' is already taken".format(user_id))
- else:
- # It's the first time this user is logging in and the mapped mxid was
- # not taken, register the user
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=attributes["display_name"],
- user_agent_ips=(user_agent, ip_address),
- )
- await self._datastore.record_user_external_id(
- self._auth_provider_id, remote_user_id, registered_user_id,
+
+ return previously_registered_user_id
+
+ return None
+
+ return await self._sso_handler.get_mxid_from_sso(
+ self._auth_provider_id,
+ remote_user_id,
+ user_agent,
+ ip_address,
+ oidc_response_to_user_attributes,
+ grandfather_existing_users,
)
- return registered_user_id
-UserAttribute = TypedDict(
- "UserAttribute", {"localpart": str, "display_name": Optional[str]}
+UserAttributeDict = TypedDict(
+ "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
)
C = TypeVar("C")
@@ -991,13 +978,15 @@ class OidcMappingProvider(Generic[C]):
raise NotImplementedError()
async def map_user_attributes(
- self, userinfo: UserInfo, token: Token
- ) -> UserAttribute:
+ self, userinfo: UserInfo, token: Token, failures: int
+ ) -> UserAttributeDict:
"""Map a `UserInfo` object into user attributes.
Args:
userinfo: An object representing the user given by the OIDC provider
token: A dict with the tokens returned by the provider
+ failures: How many times a call to this function with this
+ UserInfo has resulted in a failure.
Returns:
A dict containing the ``localpart`` and (optionally) the ``display_name``
@@ -1097,10 +1086,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
return userinfo[self._config.subject_claim]
async def map_user_attributes(
- self, userinfo: UserInfo, token: Token
- ) -> UserAttribute:
+ self, userinfo: UserInfo, token: Token, failures: int
+ ) -> UserAttributeDict:
localpart = self._config.localpart_template.render(user=userinfo).strip()
+ # Ensure only valid characters are included in the MXID.
+ localpart = map_username_to_mxid_localpart(localpart)
+
+ # Append suffix integer if last call to this function failed to produce
+ # a usable mxid.
+ localpart += str(failures) if failures else ""
+
display_name = None # type: Optional[str]
if self._config.display_name_template is not None:
display_name = self._config.display_name_template.render(
@@ -1110,7 +1106,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
if display_name == "":
display_name = None
- return UserAttribute(localpart=localpart, display_name=display_name)
+ return UserAttributeDict(localpart=localpart, display_name=display_name)
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
extras = {} # type: Dict[str, str]
|