diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 47ad96f97e..e24767b921 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -17,8 +17,9 @@ from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
import attr
+from twisted.web.http import Request
+
from synapse.api.errors import RedirectException
-from synapse.handlers._base import BaseHandler
from synapse.http.server import respond_with_html
from synapse.types import UserID, contains_invalid_mxid_characters
@@ -42,14 +43,16 @@ class UserAttributes:
emails = attr.ib(type=List[str], default=attr.Factory(list))
-class SsoHandler(BaseHandler):
+class SsoHandler:
# The number of attempts to ask the mapping provider for when generating an MXID.
_MAP_USERNAME_RETRIES = 1000
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self._store = hs.get_datastore()
+ self._server_name = hs.hostname
self._registration_handler = hs.get_registration_handler()
self._error_template = hs.config.sso_error_template
+ self._auth_handler = hs.get_auth_handler()
def render_error(
self, request, error: str, error_description: Optional[str] = None
@@ -95,7 +98,7 @@ class SsoHandler(BaseHandler):
)
# Check if we already have a mapping for this user.
- previously_registered_user_id = await self.store.get_user_by_external_id(
+ previously_registered_user_id = await self._store.get_user_by_external_id(
auth_provider_id, remote_user_id,
)
@@ -181,7 +184,7 @@ class SsoHandler(BaseHandler):
previously_registered_user_id = await grandfather_existing_users()
if previously_registered_user_id:
# Future logins should also match this user ID.
- await self.store.record_user_external_id(
+ await self._store.record_user_external_id(
auth_provider_id, remote_user_id, previously_registered_user_id
)
return previously_registered_user_id
@@ -214,8 +217,8 @@ class SsoHandler(BaseHandler):
)
# Check if this mxid already exists
- user_id = UserID(attributes.localpart, self.server_name).to_string()
- if not await self.store.get_users_by_id_case_insensitive(user_id):
+ user_id = UserID(attributes.localpart, self._server_name).to_string()
+ if not await self._store.get_users_by_id_case_insensitive(user_id):
# This mxid is free
break
else:
@@ -238,7 +241,47 @@ class SsoHandler(BaseHandler):
user_agent_ips=[(user_agent, ip_address)],
)
- await self.store.record_user_external_id(
+ await self._store.record_user_external_id(
auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
+
+ async def complete_sso_ui_auth_request(
+ self,
+ auth_provider_id: str,
+ remote_user_id: str,
+ ui_auth_session_id: str,
+ request: Request,
+ ) -> None:
+ """
+ Given an SSO ID, retrieve the user ID for it and complete UIA.
+
+ Note that this requires that the user is mapped in the "user_external_ids"
+ table. This will be the case if they have ever logged in via SAML or OIDC in
+ recentish synapse versions, but may not be for older users.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+ remote_user_id: The unique identifier from the SSO provider.
+ ui_auth_session_id: The ID of the user-interactive auth session.
+ request: The request to complete.
+ """
+
+ user_id = await self.get_sso_user_by_remote_user_id(
+ auth_provider_id, remote_user_id,
+ )
+
+ if not user_id:
+ logger.warning(
+ "Remote user %s/%s has not previously logged in here: UIA will fail",
+ auth_provider_id,
+ remote_user_id,
+ )
+ # Let the UIA flow handle this the same as if they presented creds for a
+ # different user.
+ user_id = ""
+
+ await self._auth_handler.complete_sso_ui_auth(
+ user_id, ui_auth_session_id, request
+ )
|