diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index fd6c5e9ea8..34db10ffe4 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -24,7 +24,8 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
-from synapse.http.server import respond_with_html
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
@@ -37,15 +38,11 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
- import synapse.server
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-class MappingException(Exception):
- """Used to catch errors when mapping the SAML2 response to a user."""
-
-
@attr.s(slots=True)
class Saml2SessionData:
"""Data we track about SAML2 sessions"""
@@ -57,17 +54,14 @@ class Saml2SessionData:
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
-class SamlHandler:
- def __init__(self, hs: "synapse.server.HomeServer"):
- self.hs = hs
+class SamlHandler(BaseHandler):
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
- self._auth = hs.get_auth()
+ self._saml_idp_entityid = hs.config.saml2_idp_entityid
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
- self._clock = hs.get_clock()
- self._datastore = hs.get_datastore()
- self._hostname = hs.hostname
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute
@@ -88,26 +82,9 @@ class SamlHandler:
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
# a lock on the mappings
- self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
-
- def _render_error(
- self, request, error: str, error_description: Optional[str] = None
- ) -> None:
- """Render the error template and respond to the request with it.
-
- This is used to show errors to the user. The template of this page can
- be found under `synapse/res/templates/sso_error.html`.
+ self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
- Args:
- request: The incoming request from the browser.
- We'll respond with an HTML page describing the error.
- error: A technical identifier for this error.
- error_description: A human-readable description of the error.
- """
- html = self._error_template.render(
- error=error, error_description=error_description
- )
- respond_with_html(request, 400, html)
+ self._sso_handler = hs.get_sso_handler()
def handle_redirect_request(
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
@@ -124,13 +101,13 @@ class SamlHandler:
URL to redirect to
"""
reqid, info = self._saml_client.prepare_for_authenticate(
- relay_state=client_redirect_url
+ entityid=self._saml_idp_entityid, relay_state=client_redirect_url
)
# Since SAML sessions timeout it is useful to log when they were created.
logger.info("Initiating a new SAML session: %s" % (reqid,))
- now = self._clock.time_msec()
+ now = self.clock.time_msec()
self._outstanding_requests_dict[reqid] = Saml2SessionData(
creation_time=now, ui_auth_session_id=ui_auth_session_id,
)
@@ -171,12 +148,12 @@ class SamlHandler:
# in the (user-visible) exception message, so let's log the exception here
# so we can track down the session IDs later.
logger.warning(str(e))
- self._render_error(
+ self._sso_handler.render_error(
request, "unsolicited_response", "Unexpected SAML2 login."
)
return
except Exception as e:
- self._render_error(
+ self._sso_handler.render_error(
request,
"invalid_response",
"Unable to parse SAML2 response: %s." % (e,),
@@ -184,7 +161,7 @@ class SamlHandler:
return
if saml2_auth.not_signed:
- self._render_error(
+ self._sso_handler.render_error(
request, "unsigned_respond", "SAML2 response was not signed."
)
return
@@ -210,7 +187,7 @@ class SamlHandler:
# attributes.
for requirement in self._saml2_attribute_requirements:
if not _check_attribute_requirement(saml2_auth.ava, requirement):
- self._render_error(
+ self._sso_handler.render_error(
request, "unauthorised", "You are not authorised to log in here."
)
return
@@ -226,7 +203,7 @@ class SamlHandler:
)
except MappingException as e:
logger.exception("Could not map user")
- self._render_error(request, "mapping_error", str(e))
+ self._sso_handler.render_error(request, "mapping_error", str(e))
return
# Complete the interactive auth session or the login.
@@ -272,20 +249,26 @@ class SamlHandler:
"Failed to extract remote user id from SAML response"
)
- with (await self._mapping_lock.queue(self._auth_provider_id)):
- # first of all, check if we already have a mapping for this user
- logger.info(
- "Looking for existing mapping for user %s:%s",
- self._auth_provider_id,
- remote_user_id,
+ async def saml_response_to_remapped_user_attributes(
+ failures: int,
+ ) -> UserAttributes:
+ """
+ Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form.
+
+ This is backwards compatibility for abstraction for the SSO handler.
+ """
+ # Call the mapping provider.
+ result = self._user_mapping_provider.saml_response_to_user_attributes(
+ saml2_auth, failures, client_redirect_url
)
- registered_user_id = await self._datastore.get_user_by_external_id(
- self._auth_provider_id, remote_user_id
+ # Remap some of the results.
+ return UserAttributes(
+ localpart=result.get("mxid_localpart"),
+ display_name=result.get("displayname"),
+ emails=result.get("emails"),
)
- if registered_user_id is not None:
- logger.info("Found existing mapping %s", registered_user_id)
- return registered_user_id
+ with (await self._mapping_lock.queue(self._auth_provider_id)):
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
if (
@@ -294,75 +277,34 @@ class SamlHandler:
):
attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
user_id = UserID(
- map_username_to_mxid_localpart(attrval), self._hostname
+ map_username_to_mxid_localpart(attrval), self.server_name
).to_string()
- logger.info(
+
+ logger.debug(
"Looking for existing account based on mapped %s %s",
self._grandfathered_mxid_source_attribute,
user_id,
)
- users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+ users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
registered_user_id = list(users.keys())[0]
logger.info("Grandfathering mapping to %s", registered_user_id)
- await self._datastore.record_user_external_id(
+ await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
- # Map saml response to user attributes using the configured mapping provider
- for i in range(1000):
- attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
- saml2_auth, i, client_redirect_url=client_redirect_url,
- )
-
- logger.debug(
- "Retrieved SAML attributes from user mapping provider: %s "
- "(attempt %d)",
- attribute_dict,
- i,
- )
-
- localpart = attribute_dict.get("mxid_localpart")
- if not localpart:
- raise MappingException(
- "Error parsing SAML2 response: SAML mapping provider plugin "
- "did not return a mxid_localpart value"
- )
-
- displayname = attribute_dict.get("displayname")
- emails = attribute_dict.get("emails", [])
-
- # Check if this mxid already exists
- if not await self._datastore.get_users_by_id_case_insensitive(
- UserID(localpart, self._hostname).to_string()
- ):
- # This mxid is free
- break
- else:
- # Unable to generate a username in 1000 iterations
- # Break and return error to the user
- raise MappingException(
- "Unable to generate a Matrix ID from the SAML response"
- )
-
- logger.info("Mapped SAML user to local part %s", localpart)
-
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=displayname,
- bind_emails=emails,
- user_agent_ips=(user_agent, ip_address),
- )
-
- await self._datastore.record_user_external_id(
- self._auth_provider_id, remote_user_id, registered_user_id
+ return await self._sso_handler.get_mxid_from_sso(
+ self._auth_provider_id,
+ remote_user_id,
+ user_agent,
+ ip_address,
+ saml_response_to_remapped_user_attributes,
)
- return registered_user_id
def expire_sessions(self):
- expire_before = self._clock.time_msec() - self._saml2_session_lifetime
+ expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
for reqid, data in self._outstanding_requests_dict.items():
if data.creation_time < expire_before:
@@ -474,11 +416,11 @@ class DefaultSamlMappingProvider:
)
# Use the configured mapper for this mxid_source
- base_mxid_localpart = self._mxid_mapper(mxid_source)
+ localpart = self._mxid_mapper(mxid_source)
# Append suffix integer if last call to this function failed to produce
- # a usable mxid
- localpart = base_mxid_localpart + (str(failures) if failures else "")
+ # a usable mxid.
+ localpart += str(failures) if failures else ""
# Retrieve the display name from the saml response
# If displayname is None, the mxid_localpart will be used instead
|