diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 47ad96f97e..112a7d5b2c 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -17,10 +17,12 @@ from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
import attr
+from twisted.web.http import Request
+
from synapse.api.errors import RedirectException
-from synapse.handlers._base import BaseHandler
from synapse.http.server import respond_with_html
from synapse.types import UserID, contains_invalid_mxid_characters
+from synapse.util.async_helpers import Linearizer
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -42,14 +44,19 @@ class UserAttributes:
emails = attr.ib(type=List[str], default=attr.Factory(list))
-class SsoHandler(BaseHandler):
+class SsoHandler:
# The number of attempts to ask the mapping provider for when generating an MXID.
_MAP_USERNAME_RETRIES = 1000
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self._store = hs.get_datastore()
+ self._server_name = hs.hostname
self._registration_handler = hs.get_registration_handler()
self._error_template = hs.config.sso_error_template
+ self._auth_handler = hs.get_auth_handler()
+
+ # a lock on the mappings
+ self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
def render_error(
self, request, error: str, error_description: Optional[str] = None
@@ -95,7 +102,7 @@ class SsoHandler(BaseHandler):
)
# Check if we already have a mapping for this user.
- previously_registered_user_id = await self.store.get_user_by_external_id(
+ previously_registered_user_id = await self._store.get_user_by_external_id(
auth_provider_id, remote_user_id,
)
@@ -169,24 +176,38 @@ class SsoHandler(BaseHandler):
to an additional page. (e.g. to prompt for more information)
"""
- # first of all, check if we already have a mapping for this user
- previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
- auth_provider_id, remote_user_id,
- )
- if previously_registered_user_id:
- return previously_registered_user_id
-
- # Check for grandfathering of users.
- if grandfather_existing_users:
- previously_registered_user_id = await grandfather_existing_users()
+ # grab a lock while we try to find a mapping for this user. This seems...
+ # optimistic, especially for implementations that end up redirecting to
+ # interstitial pages.
+ with await self._mapping_lock.queue(auth_provider_id):
+ # first of all, check if we already have a mapping for this user
+ previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
+ auth_provider_id, remote_user_id,
+ )
if previously_registered_user_id:
- # Future logins should also match this user ID.
- await self.store.record_user_external_id(
- auth_provider_id, remote_user_id, previously_registered_user_id
- )
return previously_registered_user_id
- # Otherwise, generate a new user.
+ # Check for grandfathering of users.
+ if grandfather_existing_users:
+ previously_registered_user_id = await grandfather_existing_users()
+ if previously_registered_user_id:
+ # Future logins should also match this user ID.
+ await self._store.record_user_external_id(
+ auth_provider_id, remote_user_id, previously_registered_user_id
+ )
+ return previously_registered_user_id
+
+ # Otherwise, generate a new user.
+ attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
+ user_id = await self._register_mapped_user(
+ attributes, auth_provider_id, remote_user_id, user_agent, ip_address,
+ )
+ return user_id
+
+ async def _call_attribute_mapper(
+ self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+ ) -> UserAttributes:
+ """Call the attribute mapper function in a loop, until we get a unique userid"""
for i in range(self._MAP_USERNAME_RETRIES):
try:
attributes = await sso_to_matrix_id_mapper(i)
@@ -214,8 +235,8 @@ class SsoHandler(BaseHandler):
)
# Check if this mxid already exists
- user_id = UserID(attributes.localpart, self.server_name).to_string()
- if not await self.store.get_users_by_id_case_insensitive(user_id):
+ user_id = UserID(attributes.localpart, self._server_name).to_string()
+ if not await self._store.get_users_by_id_case_insensitive(user_id):
# This mxid is free
break
else:
@@ -224,7 +245,16 @@ class SsoHandler(BaseHandler):
raise MappingException(
"Unable to generate a Matrix ID from the SSO response"
)
+ return attributes
+ async def _register_mapped_user(
+ self,
+ attributes: UserAttributes,
+ auth_provider_id: str,
+ remote_user_id: str,
+ user_agent: str,
+ ip_address: str,
+ ) -> str:
# Since the localpart is provided via a potentially untrusted module,
# ensure the MXID is valid before registering.
if contains_invalid_mxid_characters(attributes.localpart):
@@ -238,7 +268,47 @@ class SsoHandler(BaseHandler):
user_agent_ips=[(user_agent, ip_address)],
)
- await self.store.record_user_external_id(
+ await self._store.record_user_external_id(
auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
+
+ async def complete_sso_ui_auth_request(
+ self,
+ auth_provider_id: str,
+ remote_user_id: str,
+ ui_auth_session_id: str,
+ request: Request,
+ ) -> None:
+ """
+ Given an SSO ID, retrieve the user ID for it and complete UIA.
+
+ Note that this requires that the user is mapped in the "user_external_ids"
+ table. This will be the case if they have ever logged in via SAML or OIDC in
+ recentish synapse versions, but may not be for older users.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+ remote_user_id: The unique identifier from the SSO provider.
+ ui_auth_session_id: The ID of the user-interactive auth session.
+ request: The request to complete.
+ """
+
+ user_id = await self.get_sso_user_by_remote_user_id(
+ auth_provider_id, remote_user_id,
+ )
+
+ if not user_id:
+ logger.warning(
+ "Remote user %s/%s has not previously logged in here: UIA will fail",
+ auth_provider_id,
+ remote_user_id,
+ )
+ # Let the UIA flow handle this the same as if they presented creds for a
+ # different user.
+ user_id = ""
+
+ await self._auth_handler.complete_sso_ui_auth(
+ user_id, ui_auth_session_id, request
+ )
|