diff --git a/changelog.d/8801.feature b/changelog.d/8801.feature
new file mode 100644
index 0000000000..77f7fe4e5d
--- /dev/null
+++ b/changelog.d/8801.feature
@@ -0,0 +1 @@
+Add support for re-trying generation of a localpart for OpenID Connect mapping providers.
diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md
index 707dd73978..dee53b5d40 100644
--- a/docs/sso_mapping_providers.md
+++ b/docs/sso_mapping_providers.md
@@ -63,13 +63,22 @@ A custom mapping provider must specify the following methods:
information from.
- This method must return a string, which is the unique identifier for the
user. Commonly the ``sub`` claim of the response.
-* `map_user_attributes(self, userinfo, token)`
+* `map_user_attributes(self, userinfo, token, failures)`
- This method must be async.
- Arguments:
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
information from.
- `token` - A dictionary which includes information necessary to make
further requests to the OpenID provider.
+ - `failures` - An `int` that represents the amount of times the returned
+ mxid localpart mapping has failed. This should be used
+ to create a deduplicated mxid localpart which should be
+ returned instead. For example, if this method returns
+ `john.doe` as the value of `localpart` in the returned
+ dict, and that is already taken on the homeserver, this
+ method will be called again with the same parameters but
+ with failures=1. The method should then return a different
+ `localpart` value, such as `john.doe1`.
- Returns a dictionary with two keys:
- localpart: A required string, used to generate the Matrix ID.
- displayname: An optional string, the display name for the user.
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 34de9109ea..78c4e94a9d 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from urllib.parse import urlencode
@@ -35,15 +36,10 @@ from twisted.web.client import readBody
from synapse.config import ConfigError
from synapse.handlers._base import BaseHandler
-from synapse.handlers.sso import MappingException
+from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
-from synapse.types import (
- JsonDict,
- UserID,
- contains_invalid_mxid_characters,
- map_username_to_mxid_localpart,
-)
+from synapse.types import JsonDict, map_username_to_mxid_localpart
from synapse.util import json_decoder
if TYPE_CHECKING:
@@ -869,73 +865,51 @@ class OidcHandler(BaseHandler):
# to be strings.
remote_user_id = str(remote_user_id)
- # first of all, check if we already have a mapping for this user
- previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
- self._auth_provider_id, remote_user_id,
+ # Older mapping providers don't accept the `failures` argument, so we
+ # try and detect support.
+ mapper_signature = inspect.signature(
+ self._user_mapping_provider.map_user_attributes
)
- if previously_registered_user_id:
- return previously_registered_user_id
+ supports_failures = "failures" in mapper_signature.parameters
- # Otherwise, generate a new user.
- try:
- attributes = await self._user_mapping_provider.map_user_attributes(
- userinfo, token
- )
- except Exception as e:
- raise MappingException(
- "Could not extract user attributes from OIDC response: " + str(e)
- )
+ async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
+ """
+ Call the mapping provider to map the OIDC userinfo and token to user attributes.
- logger.debug(
- "Retrieved user attributes from user mapping provider: %r", attributes
- )
+ This is backwards compatibility for abstraction for the SSO handler.
+ """
+ if supports_failures:
+ attributes = await self._user_mapping_provider.map_user_attributes(
+ userinfo, token, failures
+ )
+ else:
+ # If the mapping provider does not support processing failures,
+ # do not continually generate the same Matrix ID since it will
+ # continue to already be in use. Note that the error raised is
+ # arbitrary and will get turned into a MappingException.
+ if failures:
+ raise RuntimeError(
+ "Mapping provider does not support de-duplicating Matrix IDs"
+ )
- localpart = attributes["localpart"]
- if not localpart:
- raise MappingException(
- "Error parsing OIDC response: OIDC mapping provider plugin "
- "did not return a localpart value"
- )
+ attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore
+ userinfo, token
+ )
- user_id = UserID(localpart, self.server_name).to_string()
- users = await self.store.get_users_by_id_case_insensitive(user_id)
- if users:
- if self._allow_existing_users:
- if len(users) == 1:
- registered_user_id = next(iter(users))
- elif user_id in users:
- registered_user_id = user_id
- else:
- raise MappingException(
- "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
- user_id, list(users.keys())
- )
- )
- else:
- # This mxid is taken
- raise MappingException("mxid '{}' is already taken".format(user_id))
- else:
- # Since the localpart is provided via a potentially untrusted module,
- # ensure the MXID is valid before registering.
- if contains_invalid_mxid_characters(localpart):
- raise MappingException("localpart is invalid: %s" % (localpart,))
-
- # It's the first time this user is logging in and the mapped mxid was
- # not taken, register the user
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=attributes["display_name"],
- user_agent_ips=[(user_agent, ip_address)],
- )
+ return UserAttributes(**attributes)
- await self.store.record_user_external_id(
- self._auth_provider_id, remote_user_id, registered_user_id,
+ return await self._sso_handler.get_mxid_from_sso(
+ self._auth_provider_id,
+ remote_user_id,
+ user_agent,
+ ip_address,
+ oidc_response_to_user_attributes,
+ self._allow_existing_users,
)
- return registered_user_id
-UserAttribute = TypedDict(
- "UserAttribute", {"localpart": str, "display_name": Optional[str]}
+UserAttributeDict = TypedDict(
+ "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
)
C = TypeVar("C")
@@ -978,13 +952,15 @@ class OidcMappingProvider(Generic[C]):
raise NotImplementedError()
async def map_user_attributes(
- self, userinfo: UserInfo, token: Token
- ) -> UserAttribute:
+ self, userinfo: UserInfo, token: Token, failures: int
+ ) -> UserAttributeDict:
"""Map a `UserInfo` object into user attributes.
Args:
userinfo: An object representing the user given by the OIDC provider
token: A dict with the tokens returned by the provider
+ failures: How many times a call to this function with this
+ UserInfo has resulted in a failure.
Returns:
A dict containing the ``localpart`` and (optionally) the ``display_name``
@@ -1084,13 +1060,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
return userinfo[self._config.subject_claim]
async def map_user_attributes(
- self, userinfo: UserInfo, token: Token
- ) -> UserAttribute:
+ self, userinfo: UserInfo, token: Token, failures: int
+ ) -> UserAttributeDict:
localpart = self._config.localpart_template.render(user=userinfo).strip()
# Ensure only valid characters are included in the MXID.
localpart = map_username_to_mxid_localpart(localpart)
+ # Append suffix integer if last call to this function failed to produce
+ # a usable mxid.
+ localpart += str(failures) if failures else ""
+
display_name = None # type: Optional[str]
if self._config.display_name_template is not None:
display_name = self._config.display_name_template.render(
@@ -1100,7 +1080,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
if display_name == "":
display_name = None
- return UserAttribute(localpart=localpart, display_name=display_name)
+ return UserAttributeDict(localpart=localpart, display_name=display_name)
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
extras = {} # type: Dict[str, str]
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 37ab42f050..34db10ffe4 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -25,13 +25,12 @@ from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.handlers._base import BaseHandler
-from synapse.handlers.sso import MappingException
+from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
from synapse.types import (
UserID,
- contains_invalid_mxid_characters,
map_username_to_mxid_localpart,
mxid_localpart_allowed_characters,
)
@@ -250,14 +249,26 @@ class SamlHandler(BaseHandler):
"Failed to extract remote user id from SAML response"
)
- with (await self._mapping_lock.queue(self._auth_provider_id)):
- # first of all, check if we already have a mapping for this user
- previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
- self._auth_provider_id, remote_user_id,
+ async def saml_response_to_remapped_user_attributes(
+ failures: int,
+ ) -> UserAttributes:
+ """
+ Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form.
+
+ This is backwards compatibility for abstraction for the SSO handler.
+ """
+ # Call the mapping provider.
+ result = self._user_mapping_provider.saml_response_to_user_attributes(
+ saml2_auth, failures, client_redirect_url
+ )
+ # Remap some of the results.
+ return UserAttributes(
+ localpart=result.get("mxid_localpart"),
+ display_name=result.get("displayname"),
+ emails=result.get("emails"),
)
- if previously_registered_user_id:
- return previously_registered_user_id
+ with (await self._mapping_lock.queue(self._auth_provider_id)):
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
if (
@@ -284,59 +295,13 @@ class SamlHandler(BaseHandler):
)
return registered_user_id
- # Map saml response to user attributes using the configured mapping provider
- for i in range(1000):
- attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
- saml2_auth, i, client_redirect_url=client_redirect_url,
- )
-
- logger.debug(
- "Retrieved SAML attributes from user mapping provider: %s "
- "(attempt %d)",
- attribute_dict,
- i,
- )
-
- localpart = attribute_dict.get("mxid_localpart")
- if not localpart:
- raise MappingException(
- "Error parsing SAML2 response: SAML mapping provider plugin "
- "did not return a mxid_localpart value"
- )
-
- displayname = attribute_dict.get("displayname")
- emails = attribute_dict.get("emails", [])
-
- # Check if this mxid already exists
- if not await self.store.get_users_by_id_case_insensitive(
- UserID(localpart, self.server_name).to_string()
- ):
- # This mxid is free
- break
- else:
- # Unable to generate a username in 1000 iterations
- # Break and return error to the user
- raise MappingException(
- "Unable to generate a Matrix ID from the SAML response"
- )
-
- # Since the localpart is provided via a potentially untrusted module,
- # ensure the MXID is valid before registering.
- if contains_invalid_mxid_characters(localpart):
- raise MappingException("localpart is invalid: %s" % (localpart,))
-
- logger.debug("Mapped SAML user to local part %s", localpart)
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=displayname,
- bind_emails=emails,
- user_agent_ips=[(user_agent, ip_address)],
- )
-
- await self.store.record_user_external_id(
- self._auth_provider_id, remote_user_id, registered_user_id
+ return await self._sso_handler.get_mxid_from_sso(
+ self._auth_provider_id,
+ remote_user_id,
+ user_agent,
+ ip_address,
+ saml_response_to_remapped_user_attributes,
)
- return registered_user_id
def expire_sessions(self):
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
@@ -451,11 +416,11 @@ class DefaultSamlMappingProvider:
)
# Use the configured mapper for this mxid_source
- base_mxid_localpart = self._mxid_mapper(mxid_source)
+ localpart = self._mxid_mapper(mxid_source)
# Append suffix integer if last call to this function failed to produce
- # a usable mxid
- localpart = base_mxid_localpart + (str(failures) if failures else "")
+ # a usable mxid.
+ localpart += str(failures) if failures else ""
# Retrieve the display name from the saml response
# If displayname is None, the mxid_localpart will be used instead
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index cf7cb7754a..d963082210 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -13,10 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
+
+import attr
from synapse.handlers._base import BaseHandler
from synapse.http.server import respond_with_html
+from synapse.types import UserID, contains_invalid_mxid_characters
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -29,9 +32,20 @@ class MappingException(Exception):
"""
+@attr.s
+class UserAttributes:
+ localpart = attr.ib(type=str)
+ display_name = attr.ib(type=Optional[str], default=None)
+ emails = attr.ib(type=List[str], default=attr.Factory(list))
+
+
class SsoHandler(BaseHandler):
+ # The number of attempts to ask the mapping provider for when generating an MXID.
+ _MAP_USERNAME_RETRIES = 1000
+
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
+ self._registration_handler = hs.get_registration_handler()
self._error_template = hs.config.sso_error_template
def render_error(
@@ -94,3 +108,142 @@ class SsoHandler(BaseHandler):
# No match.
return None
+
+ async def get_mxid_from_sso(
+ self,
+ auth_provider_id: str,
+ remote_user_id: str,
+ user_agent: str,
+ ip_address: str,
+ sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+ allow_existing_users: bool = False,
+ ) -> str:
+ """
+ Given an SSO ID, retrieve the user ID for it and possibly register the user.
+
+ This first checks if the SSO ID has previously been linked to a matrix ID,
+ if it has that matrix ID is returned regardless of the current mapping
+ logic.
+
+ The mapping function is called (potentially multiple times) to generate
+ a localpart for the user.
+
+ If an unused localpart is generated, the user is registered from the
+ given user-agent and IP address and the SSO ID is linked to this matrix
+ ID for subsequent calls.
+
+ If allow_existing_users is true the mapping function is only called once
+ and results in:
+
+ 1. The use of a previously registered matrix ID. In this case, the
+ SSO ID is linked to the matrix ID. (Note it is possible that
+ other SSO IDs are linked to the same matrix ID.)
+ 2. An unused localpart, in which case the user is registered (as
+ discussed above).
+ 3. An error if the generated localpart matches multiple pre-existing
+ matrix IDs. Generally this should not happen.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+ remote_user_id: The unique identifier from the SSO provider.
+ user_agent: The user agent of the client making the request.
+ ip_address: The IP address of the client making the request.
+ sso_to_matrix_id_mapper: A callable to generate the user attributes.
+ The only parameter is an integer which represents the amount of
+ times the returned mxid localpart mapping has failed.
+ allow_existing_users: True if the localpart returned from the
+ mapping provider can be linked to an existing matrix ID.
+
+ Returns:
+ The user ID associated with the SSO response.
+
+ Raises:
+ MappingException if there was a problem mapping the response to a user.
+ RedirectException: some mapping providers may raise this if they need
+ to redirect to an interstitial page.
+
+ """
+ # first of all, check if we already have a mapping for this user
+ previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
+ auth_provider_id, remote_user_id,
+ )
+ if previously_registered_user_id:
+ return previously_registered_user_id
+
+ # Otherwise, generate a new user.
+ for i in range(self._MAP_USERNAME_RETRIES):
+ try:
+ attributes = await sso_to_matrix_id_mapper(i)
+ except Exception as e:
+ raise MappingException(
+ "Could not extract user attributes from SSO response: " + str(e)
+ )
+
+ logger.debug(
+ "Retrieved user attributes from user mapping provider: %r (attempt %d)",
+ attributes,
+ i,
+ )
+
+ if not attributes.localpart:
+ raise MappingException(
+ "Error parsing SSO response: SSO mapping provider plugin "
+ "did not return a localpart value"
+ )
+
+ # Check if this mxid already exists
+ user_id = UserID(attributes.localpart, self.server_name).to_string()
+ users = await self.store.get_users_by_id_case_insensitive(user_id)
+ # Note, if allow_existing_users is true then the loop is guaranteed
+ # to end on the first iteration: either by matching an existing user,
+ # raising an error, or registering a new user. See the docstring for
+ # more in-depth an explanation.
+ if users and allow_existing_users:
+ # If an existing matrix ID is returned, then use it.
+ if len(users) == 1:
+ previously_registered_user_id = next(iter(users))
+ elif user_id in users:
+ previously_registered_user_id = user_id
+ else:
+ # Do not attempt to continue generating Matrix IDs.
+ raise MappingException(
+ "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+ user_id, users
+ )
+ )
+
+ # Future logins should also match this user ID.
+ await self.store.record_user_external_id(
+ auth_provider_id, remote_user_id, previously_registered_user_id
+ )
+
+ return previously_registered_user_id
+
+ elif not users:
+ # This mxid is free
+ break
+ else:
+ # Unable to generate a username in 1000 iterations
+ # Break and return error to the user
+ raise MappingException(
+ "Unable to generate a Matrix ID from the SSO response"
+ )
+
+ # Since the localpart is provided via a potentially untrusted module,
+ # ensure the MXID is valid before registering.
+ if contains_invalid_mxid_characters(attributes.localpart):
+ raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
+
+ logger.debug("Mapped SSO user to local part %s", attributes.localpart)
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=attributes.localpart,
+ default_display_name=attributes.display_name,
+ bind_emails=attributes.emails,
+ user_agent_ips=[(user_agent, ip_address)],
+ )
+
+ await self.store.record_user_external_id(
+ auth_provider_id, remote_user_id, registered_user_id
+ )
+ return registered_user_id
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index b4fa02acc4..e880d32be6 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -89,6 +89,14 @@ class TestMappingProviderExtra(TestMappingProvider):
return {"phone": userinfo["phone"]}
+class TestMappingProviderFailures(TestMappingProvider):
+ async def map_user_attributes(self, userinfo, token, failures):
+ return {
+ "localpart": userinfo["username"] + (str(failures) if failures else ""),
+ "display_name": None,
+ }
+
+
def simple_async_mock(return_value=None, raises=None):
# AsyncMock is not available in python3.5, this mimics part of its behaviour
async def cb(*args, **kwargs):
@@ -152,6 +160,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.render_error = Mock(return_value=None)
self.handler._sso_handler.render_error = self.render_error
+ # Reduce the number of attempts when generating MXIDs.
+ self.handler._sso_handler._MAP_USERNAME_RETRIES = 3
+
return hs
def metadata_edit(self, values):
@@ -693,7 +704,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
),
MappingException,
)
- self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken")
+ self.assertEqual(
+ str(e.value),
+ "Could not extract user attributes from SSO response: Mapping provider does not support de-duplicating Matrix IDs",
+ )
@override_config({"oidc_config": {"allow_existing_users": True}})
def test_map_userinfo_to_existing_user(self):
@@ -703,6 +717,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(
store.register_user(user_id=user.to_string(), password_hash=None)
)
+
+ # Map a user via SSO.
userinfo = {
"sub": "test",
"username": "test_user",
@@ -715,6 +731,23 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.assertEqual(mxid, "@test_user:test")
+ # Note that a second SSO user can be mapped to the same Matrix ID. (This
+ # requires a unique sub, but something that maps to the same matrix ID,
+ # in this case we'll just use the same username. A more realistic example
+ # would be subs which are email addresses, and mapping from the localpart
+ # of the email, e.g. bob@foo.com and bob@bar.com -> @bob:test.)
+ userinfo = {
+ "sub": "test1",
+ "username": "test_user",
+ }
+ token = {}
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user:test")
+
# Register some non-exact matching cases.
user2 = UserID.from_string("@TEST_user_2:test")
self.get_success(
@@ -762,6 +795,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "föö",
}
token = {}
+
e = self.get_failure(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
@@ -769,3 +803,55 @@ class OidcHandlerTestCase(HomeserverTestCase):
MappingException,
)
self.assertEqual(str(e.value), "localpart is invalid: föö")
+
+ @override_config(
+ {
+ "oidc_config": {
+ "user_mapping_provider": {
+ "module": __name__ + ".TestMappingProviderFailures"
+ }
+ }
+ }
+ )
+ def test_map_userinfo_to_user_retries(self):
+ """The mapping provider can retry generating an MXID if the MXID is already in use."""
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.register_user(user_id="@test_user:test", password_hash=None)
+ )
+ userinfo = {
+ "sub": "test",
+ "username": "test_user",
+ }
+ token = {}
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ # test_user is already taken, so test_user1 gets registered instead.
+ self.assertEqual(mxid, "@test_user1:test")
+
+ # Register all of the potential users for a particular username.
+ self.get_success(
+ store.register_user(user_id="@tester:test", password_hash=None)
+ )
+ for i in range(1, 3):
+ self.get_success(
+ store.register_user(user_id="@tester%d:test" % i, password_hash=None)
+ )
+
+ # Now attempt to map to a username, this will fail since all potential usernames are taken.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ }
+ e = self.get_failure(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ ),
+ MappingException,
+ )
+ self.assertEqual(
+ str(e.value), "Unable to generate a Matrix ID from the SSO response"
+ )
|