diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index d7ca2918f8..b450668f1c 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -155,6 +155,7 @@ class UsernameMappingSession:
chosen_localpart = attr.ib(type=Optional[str], default=None)
use_display_name = attr.ib(type=bool, default=True)
emails_to_use = attr.ib(type=Collection[str], default=())
+ terms_accepted_version = attr.ib(type=Optional[str], default=None)
# the HTTP cookie used to track the mapping session id
@@ -190,6 +191,8 @@ class SsoHandler:
# map from idp_id to SsoIdentityProvider
self._identity_providers = {} # type: Dict[str, SsoIdentityProvider]
+ self._consent_at_registration = hs.config.consent.user_consent_at_registration
+
def register_identity_provider(self, p: SsoIdentityProvider):
p_id = p.idp_id
assert p_id not in self._identity_providers
@@ -761,6 +764,38 @@ class SsoHandler:
)
session.emails_to_use = filtered_emails
+ # we may now need to collect consent from the user, in which case, redirect
+ # to the consent-extraction-unit
+ if self._consent_at_registration:
+ redirect_url = b"/_synapse/client/new_user_consent"
+
+ # otherwise, redirect to the completion page
+ else:
+ redirect_url = b"/_synapse/client/sso_register"
+
+ respond_with_redirect(request, redirect_url)
+
+ async def handle_terms_accepted(
+ self, request: Request, session_id: str, terms_version: str
+ ):
+ """Handle a request to the new-user 'consent' endpoint
+
+ Will serve an HTTP response to the request.
+
+ Args:
+ request: HTTP request
+ session_id: ID of the username mapping session, extracted from a cookie
+ terms_version: the version of the terms which the user viewed and consented
+ to
+ """
+ logger.info(
+ "[session %s] User consented to terms version %s",
+ session_id,
+ terms_version,
+ )
+ session = self.get_mapping_session(session_id)
+ session.terms_accepted_version = terms_version
+
# we're done; now we can register the user
respond_with_redirect(request, b"/_synapse/client/sso_register")
@@ -816,6 +851,15 @@ class SsoHandler:
path=b"/",
)
+ auth_result = {}
+ if session.terms_accepted_version:
+ # TODO: make this less awful.
+ auth_result[LoginType.TERMS] = True
+
+ await self._registration_handler.post_registration_actions(
+ user_id, auth_result, access_token=None
+ )
+
await self._auth_handler.complete_sso_login(
user_id,
request,
|