diff options
author | Richard van der Hoff <1389908+richvdh@users.noreply.github.com> | 2020-12-10 12:43:58 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-12-10 12:43:58 +0000 |
commit | c64002e1c1e95578528e96e3ae87738c4aea1d8a (patch) | |
tree | 898c68f35db5f64a5e1f6b314653796c45ddde84 /synapse/handlers/sso.py | |
parent | Fix buglet in DirectRenderJsonResource (#8897) (diff) | |
download | synapse-c64002e1c1e95578528e96e3ae87738c4aea1d8a.tar.xz |
Refactor `SsoHandler.get_mxid_from_sso` (#8900)
* Factor out _call_attribute_mapper and _register_mapped_user This is mostly an attempt to simplify `get_mxid_from_sso`. * Move mapping_lock down into SsoHandler.
Diffstat (limited to 'synapse/handlers/sso.py')
-rw-r--r-- | synapse/handlers/sso.py | 57 |
1 files changed, 42 insertions, 15 deletions
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index e24767b921..112a7d5b2c 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -22,6 +22,7 @@ from twisted.web.http import Request from synapse.api.errors import RedirectException from synapse.http.server import respond_with_html from synapse.types import UserID, contains_invalid_mxid_characters +from synapse.util.async_helpers import Linearizer if TYPE_CHECKING: from synapse.server import HomeServer @@ -54,6 +55,9 @@ class SsoHandler: self._error_template = hs.config.sso_error_template self._auth_handler = hs.get_auth_handler() + # a lock on the mappings + self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock()) + def render_error( self, request, error: str, error_description: Optional[str] = None ) -> None: @@ -172,24 +176,38 @@ class SsoHandler: to an additional page. (e.g. to prompt for more information) """ - # first of all, check if we already have a mapping for this user - previously_registered_user_id = await self.get_sso_user_by_remote_user_id( - auth_provider_id, remote_user_id, - ) - if previously_registered_user_id: - return previously_registered_user_id - - # Check for grandfathering of users. - if grandfather_existing_users: - previously_registered_user_id = await grandfather_existing_users() + # grab a lock while we try to find a mapping for this user. This seems... + # optimistic, especially for implementations that end up redirecting to + # interstitial pages. + with await self._mapping_lock.queue(auth_provider_id): + # first of all, check if we already have a mapping for this user + previously_registered_user_id = await self.get_sso_user_by_remote_user_id( + auth_provider_id, remote_user_id, + ) if previously_registered_user_id: - # Future logins should also match this user ID. - await self._store.record_user_external_id( - auth_provider_id, remote_user_id, previously_registered_user_id - ) return previously_registered_user_id - # Otherwise, generate a new user. + # Check for grandfathering of users. + if grandfather_existing_users: + previously_registered_user_id = await grandfather_existing_users() + if previously_registered_user_id: + # Future logins should also match this user ID. + await self._store.record_user_external_id( + auth_provider_id, remote_user_id, previously_registered_user_id + ) + return previously_registered_user_id + + # Otherwise, generate a new user. + attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper) + user_id = await self._register_mapped_user( + attributes, auth_provider_id, remote_user_id, user_agent, ip_address, + ) + return user_id + + async def _call_attribute_mapper( + self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], + ) -> UserAttributes: + """Call the attribute mapper function in a loop, until we get a unique userid""" for i in range(self._MAP_USERNAME_RETRIES): try: attributes = await sso_to_matrix_id_mapper(i) @@ -227,7 +245,16 @@ class SsoHandler: raise MappingException( "Unable to generate a Matrix ID from the SSO response" ) + return attributes + async def _register_mapped_user( + self, + attributes: UserAttributes, + auth_provider_id: str, + remote_user_id: str, + user_agent: str, + ip_address: str, + ) -> str: # Since the localpart is provided via a potentially untrusted module, # ensure the MXID is valid before registering. if contains_invalid_mxid_characters(attributes.localpart): |