diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index bb81c0e81d..d29b066a56 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
class BaseHandler:
"""
Common base class for the event handlers.
+
+ Deprecated: new code should not use this. Instead, Handler classes should define the
+ fields they actually need. The utility methods should either be factored out to
+ standalone helper functions, or to different Handler classes.
"""
def __init__(self, hs: "HomeServer"):
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c7dc07008a..afae6d3272 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -36,6 +36,8 @@ import attr
import bcrypt
import pymacaroons
+from twisted.web.http import Request
+
from synapse.api.constants import LoginType
from synapse.api.errors import (
AuthError,
@@ -193,9 +195,7 @@ class AuthHandler(BaseHandler):
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled
- self._sso_enabled = (
- hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
- )
+ self._password_localdb_enabled = hs.config.password_localdb_enabled
# we keep this as a list despite the O(N^2) implication so that we can
# keep PASSWORD first and avoid confusing clients which pick the first
@@ -205,7 +205,7 @@ class AuthHandler(BaseHandler):
# start out by assuming PASSWORD is enabled; we will remove it later if not.
login_types = []
- if hs.config.password_localdb_enabled:
+ if self._password_localdb_enabled:
login_types.append(LoginType.PASSWORD)
for provider in self.password_providers:
@@ -219,14 +219,6 @@ class AuthHandler(BaseHandler):
self._supported_login_types = login_types
- # Login types and UI Auth types have a heavy overlap, but are not
- # necessarily identical. Login types have SSO (and other login types)
- # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
- ui_auth_types = login_types.copy()
- if self._sso_enabled:
- ui_auth_types.append(LoginType.SSO)
- self._supported_ui_auth_types = ui_auth_types
-
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter(
@@ -339,7 +331,10 @@ class AuthHandler(BaseHandler):
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
# build a list of supported flows
- flows = [[login_type] for login_type in self._supported_ui_auth_types]
+ supported_ui_auth_types = await self._get_available_ui_auth_types(
+ requester.user
+ )
+ flows = [[login_type] for login_type in supported_ui_auth_types]
try:
result, params, session_id = await self.check_ui_auth(
@@ -351,7 +346,7 @@ class AuthHandler(BaseHandler):
raise
# find the completed login type
- for login_type in self._supported_ui_auth_types:
+ for login_type in supported_ui_auth_types:
if login_type not in result:
continue
@@ -367,6 +362,41 @@ class AuthHandler(BaseHandler):
return params, session_id
+ async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
+ """Get a list of the authentication types this user can use
+ """
+
+ ui_auth_types = set()
+
+ # if the HS supports password auth, and the user has a non-null password, we
+ # support password auth
+ if self._password_localdb_enabled and self._password_enabled:
+ lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
+ if lookupres:
+ _, password_hash = lookupres
+ if password_hash:
+ ui_auth_types.add(LoginType.PASSWORD)
+
+ # also allow auth from password providers
+ for provider in self.password_providers:
+ for t in provider.get_supported_login_types().keys():
+ if t == LoginType.PASSWORD and not self._password_enabled:
+ continue
+ ui_auth_types.add(t)
+
+ # if sso is enabled, allow the user to log in via SSO iff they have a mapping
+ # from sso to mxid.
+ if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
+ if await self.store.get_external_ids_by_user(user.to_string()):
+ ui_auth_types.add(LoginType.SSO)
+
+ # Our CAS impl does not (yet) correctly register users in user_external_ids,
+ # so always offer that if it's available.
+ if self.hs.config.cas.cas_enabled:
+ ui_auth_types.add(LoginType.SSO)
+
+ return ui_auth_types
+
def get_enabled_auth_types(self):
"""Return the enabled user-interactive authentication types
@@ -1029,7 +1059,7 @@ class AuthHandler(BaseHandler):
if result:
return result
- if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
+ if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
known_login_type = True
# we've already checked that there is a (valid) password field
@@ -1303,15 +1333,14 @@ class AuthHandler(BaseHandler):
)
async def complete_sso_ui_auth(
- self, registered_user_id: str, session_id: str, request: SynapseRequest,
+ self, registered_user_id: str, session_id: str, request: Request,
):
"""Having figured out a mxid for this user, complete the HTTP request
Args:
registered_user_id: The registered user ID to complete SSO login for.
+ session_id: The ID of the user-interactive auth session.
request: The request to complete.
- client_redirect_url: The URL to which to redirect the user at the end of the
- process.
"""
# Mark the stage of the authentication as successful.
# Save the user who authenticated with SSO, this will be used to ensure
@@ -1327,7 +1356,7 @@ class AuthHandler(BaseHandler):
async def complete_sso_login(
self,
registered_user_id: str,
- request: SynapseRequest,
+ request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
):
@@ -1355,7 +1384,7 @@ class AuthHandler(BaseHandler):
def _complete_sso_login(
self,
registered_user_id: str,
- request: SynapseRequest,
+ request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
):
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index b9799090f7..df82e60b33 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -140,7 +140,7 @@ class FederationHandler(BaseHandler):
self._message_handler = hs.get_message_handler()
self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config
- self.http_client = hs.get_simple_http_client()
+ self.http_client = hs.get_proxied_blacklisted_http_client()
self._instance_name = hs.get_instance_name()
self._replication = hs.get_replication_data_handler()
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 9b3c6b4551..7301c24710 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -46,13 +46,13 @@ class IdentityHandler(BaseHandler):
def __init__(self, hs):
super().__init__(hs)
+ # An HTTP client for contacting trusted URLs.
self.http_client = SimpleHttpClient(hs)
- # We create a blacklisting instance of SimpleHttpClient for contacting identity
- # servers specified by clients
+ # An HTTP client for contacting identity servers specified by clients.
self.blacklisting_http_client = SimpleHttpClient(
hs, ip_blacklist=hs.config.federation_ip_range_blacklist
)
- self.federation_http_client = hs.get_http_client()
+ self.federation_http_client = hs.get_federation_http_client()
self.hs = hs
async def threepid_from_creds(
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index c605f7082a..f626117f76 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -674,6 +674,21 @@ class OidcHandler(BaseHandler):
self._sso_handler.render_error(request, "invalid_token", str(e))
return
+ # first check if we're doing a UIA
+ if ui_auth_session_id:
+ try:
+ remote_user_id = self._remote_id_from_userinfo(userinfo)
+ except Exception as e:
+ logger.exception("Could not extract remote user id")
+ self._sso_handler.render_error(request, "mapping_error", str(e))
+ return
+
+ return await self._sso_handler.complete_sso_ui_auth_request(
+ self._auth_provider_id, remote_user_id, ui_auth_session_id, request
+ )
+
+ # otherwise, it's a login
+
# Pull out the user-agent and IP from the request.
user_agent = request.get_user_agent("")
ip_address = self.hs.get_ip_from_request(request)
@@ -698,14 +713,9 @@ class OidcHandler(BaseHandler):
extra_attributes = await get_extra_attributes(userinfo, token)
# and finally complete the login
- if ui_auth_session_id:
- await self._auth_handler.complete_sso_ui_auth(
- user_id, ui_auth_session_id, request
- )
- else:
- await self._auth_handler.complete_sso_login(
- user_id, request, client_redirect_url, extra_attributes
- )
+ await self._auth_handler.complete_sso_login(
+ user_id, request, client_redirect_url, extra_attributes
+ )
def _generate_oidc_session_token(
self,
@@ -856,14 +866,11 @@ class OidcHandler(BaseHandler):
The mxid of the user
"""
try:
- remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+ remote_user_id = self._remote_id_from_userinfo(userinfo)
except Exception as e:
raise MappingException(
"Failed to extract subject from OIDC response: %s" % (e,)
)
- # Some OIDC providers use integer IDs, but Synapse expects external IDs
- # to be strings.
- remote_user_id = str(remote_user_id)
# Older mapping providers don't accept the `failures` argument, so we
# try and detect support.
@@ -933,6 +940,19 @@ class OidcHandler(BaseHandler):
grandfather_existing_users,
)
+ def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
+ """Extract the unique remote id from an OIDC UserInfo block
+
+ Args:
+ userinfo: An object representing the user given by the OIDC provider
+ Returns:
+ remote user id
+ """
+ remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+ # Some OIDC providers use integer IDs, but Synapse expects external IDs
+ # to be strings.
+ return str(remote_user_id)
+
UserAttributeDict = TypedDict(
"UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 76d4169fe2..5846f08609 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -183,6 +183,24 @@ class SamlHandler(BaseHandler):
saml2_auth.in_response_to, None
)
+ # first check if we're doing a UIA
+ if current_session and current_session.ui_auth_session_id:
+ try:
+ remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
+ except MappingException as e:
+ logger.exception("Failed to extract remote user id from SAML response")
+ self._sso_handler.render_error(request, "mapping_error", str(e))
+ return
+
+ return await self._sso_handler.complete_sso_ui_auth_request(
+ self._auth_provider_id,
+ remote_user_id,
+ current_session.ui_auth_session_id,
+ request,
+ )
+
+ # otherwise, we're handling a login request.
+
# Ensure that the attributes of the logged in user meet the required
# attributes.
for requirement in self._saml2_attribute_requirements:
@@ -206,14 +224,7 @@ class SamlHandler(BaseHandler):
self._sso_handler.render_error(request, "mapping_error", str(e))
return
- # Complete the interactive auth session or the login.
- if current_session and current_session.ui_auth_session_id:
- await self._auth_handler.complete_sso_ui_auth(
- user_id, current_session.ui_auth_session_id, request
- )
-
- else:
- await self._auth_handler.complete_sso_login(user_id, request, relay_state)
+ await self._auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user(
self,
@@ -239,16 +250,10 @@ class SamlHandler(BaseHandler):
RedirectException: some mapping providers may raise this if they need
to redirect to an interstitial page.
"""
-
- remote_user_id = self._user_mapping_provider.get_remote_user_id(
+ remote_user_id = self._remote_id_from_saml_response(
saml2_auth, client_redirect_url
)
- if not remote_user_id:
- raise MappingException(
- "Failed to extract remote user id from SAML response"
- )
-
async def saml_response_to_remapped_user_attributes(
failures: int,
) -> UserAttributes:
@@ -304,6 +309,35 @@ class SamlHandler(BaseHandler):
grandfather_existing_users,
)
+ def _remote_id_from_saml_response(
+ self,
+ saml2_auth: saml2.response.AuthnResponse,
+ client_redirect_url: Optional[str],
+ ) -> str:
+ """Extract the unique remote id from a SAML2 AuthnResponse
+
+ Args:
+ saml2_auth: The parsed SAML2 response.
+ client_redirect_url: The redirect URL passed in by the client.
+ Returns:
+ remote user id
+
+ Raises:
+ MappingException if there was an error extracting the user id
+ """
+ # It's not obvious why we need to pass in the redirect URI to the mapping
+ # provider, but we do :/
+ remote_user_id = self._user_mapping_provider.get_remote_user_id(
+ saml2_auth, client_redirect_url
+ )
+
+ if not remote_user_id:
+ raise MappingException(
+ "Failed to extract remote user id from SAML response"
+ )
+
+ return remote_user_id
+
def expire_sessions(self):
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 47ad96f97e..e24767b921 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -17,8 +17,9 @@ from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
import attr
+from twisted.web.http import Request
+
from synapse.api.errors import RedirectException
-from synapse.handlers._base import BaseHandler
from synapse.http.server import respond_with_html
from synapse.types import UserID, contains_invalid_mxid_characters
@@ -42,14 +43,16 @@ class UserAttributes:
emails = attr.ib(type=List[str], default=attr.Factory(list))
-class SsoHandler(BaseHandler):
+class SsoHandler:
# The number of attempts to ask the mapping provider for when generating an MXID.
_MAP_USERNAME_RETRIES = 1000
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self._store = hs.get_datastore()
+ self._server_name = hs.hostname
self._registration_handler = hs.get_registration_handler()
self._error_template = hs.config.sso_error_template
+ self._auth_handler = hs.get_auth_handler()
def render_error(
self, request, error: str, error_description: Optional[str] = None
@@ -95,7 +98,7 @@ class SsoHandler(BaseHandler):
)
# Check if we already have a mapping for this user.
- previously_registered_user_id = await self.store.get_user_by_external_id(
+ previously_registered_user_id = await self._store.get_user_by_external_id(
auth_provider_id, remote_user_id,
)
@@ -181,7 +184,7 @@ class SsoHandler(BaseHandler):
previously_registered_user_id = await grandfather_existing_users()
if previously_registered_user_id:
# Future logins should also match this user ID.
- await self.store.record_user_external_id(
+ await self._store.record_user_external_id(
auth_provider_id, remote_user_id, previously_registered_user_id
)
return previously_registered_user_id
@@ -214,8 +217,8 @@ class SsoHandler(BaseHandler):
)
# Check if this mxid already exists
- user_id = UserID(attributes.localpart, self.server_name).to_string()
- if not await self.store.get_users_by_id_case_insensitive(user_id):
+ user_id = UserID(attributes.localpart, self._server_name).to_string()
+ if not await self._store.get_users_by_id_case_insensitive(user_id):
# This mxid is free
break
else:
@@ -238,7 +241,47 @@ class SsoHandler(BaseHandler):
user_agent_ips=[(user_agent, ip_address)],
)
- await self.store.record_user_external_id(
+ await self._store.record_user_external_id(
auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
+
+ async def complete_sso_ui_auth_request(
+ self,
+ auth_provider_id: str,
+ remote_user_id: str,
+ ui_auth_session_id: str,
+ request: Request,
+ ) -> None:
+ """
+ Given an SSO ID, retrieve the user ID for it and complete UIA.
+
+ Note that this requires that the user is mapped in the "user_external_ids"
+ table. This will be the case if they have ever logged in via SAML or OIDC in
+ recentish synapse versions, but may not be for older users.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+ remote_user_id: The unique identifier from the SSO provider.
+ ui_auth_session_id: The ID of the user-interactive auth session.
+ request: The request to complete.
+ """
+
+ user_id = await self.get_sso_user_by_remote_user_id(
+ auth_provider_id, remote_user_id,
+ )
+
+ if not user_id:
+ logger.warning(
+ "Remote user %s/%s has not previously logged in here: UIA will fail",
+ auth_provider_id,
+ remote_user_id,
+ )
+ # Let the UIA flow handle this the same as if they presented creds for a
+ # different user.
+ user_id = ""
+
+ await self._auth_handler.complete_sso_ui_auth(
+ user_id, ui_auth_session_id, request
+ )
|