summary refs log tree commit diff
path: root/synapse/handlers/sso.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/sso.py')
-rw-r--r--synapse/handlers/sso.py96
1 files changed, 88 insertions, 8 deletions
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index ff4750999a..b450668f1c 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -14,7 +14,16 @@
 # limitations under the License.
 import abc
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Mapping, Optional
+from typing import (
+    TYPE_CHECKING,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    Mapping,
+    Optional,
+    Set,
+)
 from urllib.parse import urlencode
 
 import attr
@@ -29,7 +38,7 @@ from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http import get_request_user_agent
 from synapse.http.server import respond_with_html, respond_with_redirect
 from synapse.http.site import SynapseRequest
-from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
+from synapse.types import Collection, JsonDict, UserID, contains_invalid_mxid_characters
 from synapse.util.async_helpers import Linearizer
 from synapse.util.stringutils import random_string
 
@@ -115,7 +124,7 @@ class UserAttributes:
     # enter one.
     localpart = attr.ib(type=Optional[str])
     display_name = attr.ib(type=Optional[str], default=None)
-    emails = attr.ib(type=List[str], default=attr.Factory(list))
+    emails = attr.ib(type=Collection[str], default=attr.Factory(list))
 
 
 @attr.s(slots=True)
@@ -130,7 +139,7 @@ class UsernameMappingSession:
 
     # attributes returned by the ID mapper
     display_name = attr.ib(type=Optional[str])
-    emails = attr.ib(type=List[str])
+    emails = attr.ib(type=Collection[str])
 
     # An optional dictionary of extra attributes to be provided to the client in the
     # login response.
@@ -144,6 +153,9 @@ class UsernameMappingSession:
 
     # choices made by the user
     chosen_localpart = attr.ib(type=Optional[str], default=None)
+    use_display_name = attr.ib(type=bool, default=True)
+    emails_to_use = attr.ib(type=Collection[str], default=())
+    terms_accepted_version = attr.ib(type=Optional[str], default=None)
 
 
 # the HTTP cookie used to track the mapping session id
@@ -179,6 +191,8 @@ class SsoHandler:
         # map from idp_id to SsoIdentityProvider
         self._identity_providers = {}  # type: Dict[str, SsoIdentityProvider]
 
+        self._consent_at_registration = hs.config.consent.user_consent_at_registration
+
     def register_identity_provider(self, p: SsoIdentityProvider):
         p_id = p.idp_id
         assert p_id not in self._identity_providers
@@ -710,7 +724,12 @@ class SsoHandler:
         return not user_infos
 
     async def handle_submit_username_request(
-        self, request: SynapseRequest, localpart: str, session_id: str
+        self,
+        request: SynapseRequest,
+        session_id: str,
+        localpart: str,
+        use_display_name: bool,
+        emails_to_use: Iterable[str],
     ) -> None:
         """Handle a request to the username-picker 'submit' endpoint
 
@@ -720,11 +739,62 @@ class SsoHandler:
             request: HTTP request
             localpart: localpart requested by the user
             session_id: ID of the username mapping session, extracted from a cookie
+            use_display_name: whether the user wants to use the suggested display name
+            emails_to_use: emails that the user would like to use
         """
         session = self.get_mapping_session(session_id)
 
         # update the session with the user's choices
         session.chosen_localpart = localpart
+        session.use_display_name = use_display_name
+
+        emails_from_idp = set(session.emails)
+        filtered_emails = set()  # type: Set[str]
+
+        # we iterate through the list rather than just building a set conjunction, so
+        # that we can log attempts to use unknown addresses
+        for email in emails_to_use:
+            if email in emails_from_idp:
+                filtered_emails.add(email)
+            else:
+                logger.warning(
+                    "[session %s] ignoring user request to use unknown email address %r",
+                    session_id,
+                    email,
+                )
+        session.emails_to_use = filtered_emails
+
+        # we may now need to collect consent from the user, in which case, redirect
+        # to the consent-extraction-unit
+        if self._consent_at_registration:
+            redirect_url = b"/_synapse/client/new_user_consent"
+
+        # otherwise, redirect to the completion page
+        else:
+            redirect_url = b"/_synapse/client/sso_register"
+
+        respond_with_redirect(request, redirect_url)
+
+    async def handle_terms_accepted(
+        self, request: Request, session_id: str, terms_version: str
+    ):
+        """Handle a request to the new-user 'consent' endpoint
+
+        Will serve an HTTP response to the request.
+
+        Args:
+            request: HTTP request
+            session_id: ID of the username mapping session, extracted from a cookie
+            terms_version: the version of the terms which the user viewed and consented
+                to
+        """
+        logger.info(
+            "[session %s] User consented to terms version %s",
+            session_id,
+            terms_version,
+        )
+        session = self.get_mapping_session(session_id)
+        session.terms_accepted_version = terms_version
 
         # we're done; now we can register the user
         respond_with_redirect(request, b"/_synapse/client/sso_register")
@@ -747,11 +817,12 @@ class SsoHandler:
         )
 
         attributes = UserAttributes(
-            localpart=session.chosen_localpart,
-            display_name=session.display_name,
-            emails=session.emails,
+            localpart=session.chosen_localpart, emails=session.emails_to_use,
         )
 
+        if session.use_display_name:
+            attributes.display_name = session.display_name
+
         # the following will raise a 400 error if the username has been taken in the
         # meantime.
         user_id = await self._register_mapped_user(
@@ -780,6 +851,15 @@ class SsoHandler:
             path=b"/",
         )
 
+        auth_result = {}
+        if session.terms_accepted_version:
+            # TODO: make this less awful.
+            auth_result[LoginType.TERMS] = True
+
+        await self._registration_handler.post_registration_actions(
+            user_id, auth_result, access_token=None
+        )
+
         await self._auth_handler.complete_sso_login(
             user_id,
             request,