summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/cas_handler.py71
-rw-r--r--synapse/handlers/oidc_handler.py2
-rw-r--r--synapse/handlers/register.py214
-rw-r--r--synapse/handlers/saml_handler.py6
4 files changed, 171 insertions, 122 deletions
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 048a3b3c0b..f4ea0a9767 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import urllib
-from typing import Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Optional, Tuple
 from xml.etree import ElementTree as ET
 
 from twisted.web.client import PartialDownloadError
@@ -23,6 +23,9 @@ from synapse.api.errors import Codes, LoginError
 from synapse.http.site import SynapseRequest
 from synapse.types import UserID, map_username_to_mxid_localpart
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -31,10 +34,10 @@ class CasHandler:
     Utility class for to handle the response from a CAS SSO service.
 
     Args:
-        hs (synapse.server.HomeServer)
+        hs
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self._hostname = hs.hostname
         self._auth_handler = hs.get_auth_handler()
@@ -200,27 +203,57 @@ class CasHandler:
             args["session"] = session
         username, user_display_name = await self._validate_ticket(ticket, args)
 
-        localpart = map_username_to_mxid_localpart(username)
-        user_id = UserID(localpart, self._hostname).to_string()
-        registered_user_id = await self._auth_handler.check_user_exists(user_id)
+        # Pull out the user-agent and IP from the request.
+        user_agent = request.get_user_agent("")
+        ip_address = self.hs.get_ip_from_request(request)
+
+        # Get the matrix ID from the CAS username.
+        user_id = await self._map_cas_user_to_matrix_user(
+            username, user_display_name, user_agent, ip_address
+        )
 
         if session:
             await self._auth_handler.complete_sso_ui_auth(
-                registered_user_id, session, request,
+                user_id, session, request,
             )
-
         else:
-            if not registered_user_id:
-                # Pull out the user-agent and IP from the request.
-                user_agent = request.get_user_agent("")
-                ip_address = self.hs.get_ip_from_request(request)
-
-                registered_user_id = await self._registration_handler.register_user(
-                    localpart=localpart,
-                    default_display_name=user_display_name,
-                    user_agent_ips=(user_agent, ip_address),
-                )
+            # If this not a UI auth request than there must be a redirect URL.
+            assert client_redirect_url
 
             await self._auth_handler.complete_sso_login(
-                registered_user_id, request, client_redirect_url
+                user_id, request, client_redirect_url
             )
+
+    async def _map_cas_user_to_matrix_user(
+        self,
+        remote_user_id: str,
+        display_name: Optional[str],
+        user_agent: str,
+        ip_address: str,
+    ) -> str:
+        """
+        Given a CAS username, retrieve the user ID for it and possibly register the user.
+
+        Args:
+            remote_user_id: The username from the CAS response.
+            display_name: The display name from the CAS response.
+            user_agent: The user agent of the client making the request.
+            ip_address: The IP address of the client making the request.
+
+        Returns:
+             The user ID associated with this response.
+        """
+
+        localpart = map_username_to_mxid_localpart(remote_user_id)
+        user_id = UserID(localpart, self._hostname).to_string()
+        registered_user_id = await self._auth_handler.check_user_exists(user_id)
+
+        # If the user does not exist, register it.
+        if not registered_user_id:
+            registered_user_id = await self._registration_handler.register_user(
+                localpart=localpart,
+                default_display_name=display_name,
+                user_agent_ips=[(user_agent, ip_address)],
+            )
+
+        return registered_user_id
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 4bfd8d5617..34de9109ea 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -925,7 +925,7 @@ class OidcHandler(BaseHandler):
             registered_user_id = await self._registration_handler.register_user(
                 localpart=localpart,
                 default_display_name=attributes["display_name"],
-                user_agent_ips=(user_agent, ip_address),
+                user_agent_ips=[(user_agent, ip_address)],
             )
 
         await self.store.record_user_external_id(
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 252f700786..0d85fd0868 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -15,10 +15,12 @@
 
 """Contains functions for registering clients."""
 import logging
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from synapse import types
 from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
 from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
+from synapse.appservice import ApplicationService
 from synapse.config.server import is_threepid_reserved
 from synapse.http.servlet import assert_params_in_dict
 from synapse.replication.http.login import RegisterDeviceReplicationServlet
@@ -32,16 +34,14 @@ from synapse.types import RoomAlias, UserID, create_requester
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class RegistrationHandler(BaseHandler):
-    def __init__(self, hs):
-        """
-
-        Args:
-            hs (synapse.server.HomeServer):
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.hs = hs
         self.auth = hs.get_auth()
@@ -71,7 +71,10 @@ class RegistrationHandler(BaseHandler):
         self.session_lifetime = hs.config.session_lifetime
 
     async def check_username(
-        self, localpart, guest_access_token=None, assigned_user_id=None
+        self,
+        localpart: str,
+        guest_access_token: Optional[str] = None,
+        assigned_user_id: Optional[str] = None,
     ):
         if types.contains_invalid_mxid_characters(localpart):
             raise SynapseError(
@@ -140,39 +143,45 @@ class RegistrationHandler(BaseHandler):
 
     async def register_user(
         self,
-        localpart=None,
-        password_hash=None,
-        guest_access_token=None,
-        make_guest=False,
-        admin=False,
-        threepid=None,
-        user_type=None,
-        default_display_name=None,
-        address=None,
-        bind_emails=[],
-        by_admin=False,
-        user_agent_ips=None,
-    ):
+        localpart: Optional[str] = None,
+        password_hash: Optional[str] = None,
+        guest_access_token: Optional[str] = None,
+        make_guest: bool = False,
+        admin: bool = False,
+        threepid: Optional[dict] = None,
+        user_type: Optional[str] = None,
+        default_display_name: Optional[str] = None,
+        address: Optional[str] = None,
+        bind_emails: List[str] = [],
+        by_admin: bool = False,
+        user_agent_ips: Optional[List[Tuple[str, str]]] = None,
+    ) -> str:
         """Registers a new client on the server.
 
         Args:
             localpart: The local part of the user ID to register. If None,
               one will be generated.
-            password_hash (str|None): The hashed password to assign to this user so they can
+            password_hash: The hashed password to assign to this user so they can
               login again. This can be None which means they cannot login again
               via a password (e.g. the user is an application service user).
-            user_type (str|None): type of user. One of the values from
+            guest_access_token: The access token used when this was a guest
+                account.
+            make_guest: True if the the new user should be guest,
+                false to add a regular user account.
+            admin: True if the user should be registered as a server admin.
+            threepid: The threepid used for registering, if any.
+            user_type: type of user. One of the values from
               api.constants.UserTypes, or None for a normal user.
-            default_display_name (unicode|None): if set, the new user's displayname
+            default_display_name: if set, the new user's displayname
               will be set to this. Defaults to 'localpart'.
-            address (str|None): the IP address used to perform the registration.
-            bind_emails (List[str]): list of emails to bind to this account.
-            by_admin (bool): True if this registration is being made via the
+            address: the IP address used to perform the registration.
+            bind_emails: list of emails to bind to this account.
+            by_admin: True if this registration is being made via the
               admin api, otherwise False.
-            user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used
+            user_agent_ips: Tuples of IP addresses and user-agents used
                 during the registration process.
         Returns:
-            str: user_id
+            The registere user_id.
         Raises:
             SynapseError if there was a problem registering.
         """
@@ -236,8 +245,10 @@ class RegistrationHandler(BaseHandler):
         else:
             # autogen a sequential user ID
             fail_count = 0
-            user = None
-            while not user:
+            # If a default display name is not given, generate one.
+            generate_display_name = default_display_name is None
+            # This breaks on successful registration *or* errors after 10 failures.
+            while True:
                 # Fail after being unable to find a suitable ID a few times
                 if fail_count > 10:
                     raise SynapseError(500, "Unable to find a suitable guest user ID")
@@ -246,7 +257,7 @@ class RegistrationHandler(BaseHandler):
                 user = UserID(localpart, self.hs.hostname)
                 user_id = user.to_string()
                 self.check_user_id_not_appservice_exclusive(user_id)
-                if default_display_name is None:
+                if generate_display_name:
                     default_display_name = localpart
                 try:
                     await self.register_with_store(
@@ -262,8 +273,6 @@ class RegistrationHandler(BaseHandler):
                     break
                 except SynapseError:
                     # if user id is taken, just generate another
-                    user = None
-                    user_id = None
                     fail_count += 1
 
         if not self.hs.config.user_consent_at_registration:
@@ -295,7 +304,7 @@ class RegistrationHandler(BaseHandler):
 
         return user_id
 
-    async def _create_and_join_rooms(self, user_id: str):
+    async def _create_and_join_rooms(self, user_id: str) -> None:
         """
         Create the auto-join rooms and join or invite the user to them.
 
@@ -379,7 +388,7 @@ class RegistrationHandler(BaseHandler):
             except Exception as e:
                 logger.error("Failed to join new user to %r: %r", r, e)
 
-    async def _join_rooms(self, user_id: str):
+    async def _join_rooms(self, user_id: str) -> None:
         """
         Join or invite the user to the auto-join rooms.
 
@@ -425,6 +434,9 @@ class RegistrationHandler(BaseHandler):
 
                 # Send the invite, if necessary.
                 if requires_invite:
+                    # If an invite is required, there must be a auto-join user ID.
+                    assert self.hs.config.registration.auto_join_user_id
+
                     await room_member_handler.update_membership(
                         requester=create_requester(
                             self.hs.config.registration.auto_join_user_id,
@@ -456,7 +468,7 @@ class RegistrationHandler(BaseHandler):
             except Exception as e:
                 logger.error("Failed to join new user to %r: %r", r, e)
 
-    async def _auto_join_rooms(self, user_id: str):
+    async def _auto_join_rooms(self, user_id: str) -> None:
         """Automatically joins users to auto join rooms - creating the room in the first place
         if the user is the first to be created.
 
@@ -479,16 +491,16 @@ class RegistrationHandler(BaseHandler):
         else:
             await self._join_rooms(user_id)
 
-    async def post_consent_actions(self, user_id):
+    async def post_consent_actions(self, user_id: str) -> None:
         """A series of registration actions that can only be carried out once consent
         has been granted
 
         Args:
-            user_id (str): The user to join
+            user_id: The user to join
         """
         await self._auto_join_rooms(user_id)
 
-    async def appservice_register(self, user_localpart, as_token):
+    async def appservice_register(self, user_localpart: str, as_token: str) -> str:
         user = UserID(user_localpart, self.hs.hostname)
         user_id = user.to_string()
         service = self.store.get_app_service_by_token(as_token)
@@ -513,7 +525,9 @@ class RegistrationHandler(BaseHandler):
         )
         return user_id
 
-    def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
+    def check_user_id_not_appservice_exclusive(
+        self, user_id: str, allowed_appservice: Optional[ApplicationService] = None
+    ) -> None:
         # don't allow people to register the server notices mxid
         if self._server_notices_mxid is not None:
             if user_id == self._server_notices_mxid:
@@ -537,12 +551,12 @@ class RegistrationHandler(BaseHandler):
                     errcode=Codes.EXCLUSIVE,
                 )
 
-    def check_registration_ratelimit(self, address):
+    def check_registration_ratelimit(self, address: Optional[str]) -> None:
         """A simple helper method to check whether the registration rate limit has been hit
         for a given IP address
 
         Args:
-            address (str|None): the IP address used to perform the registration. If this is
+            address: the IP address used to perform the registration. If this is
                 None, no ratelimiting will be performed.
 
         Raises:
@@ -553,42 +567,39 @@ class RegistrationHandler(BaseHandler):
 
         self.ratelimiter.ratelimit(address)
 
-    def register_with_store(
+    async def register_with_store(
         self,
-        user_id,
-        password_hash=None,
-        was_guest=False,
-        make_guest=False,
-        appservice_id=None,
-        create_profile_with_displayname=None,
-        admin=False,
-        user_type=None,
-        address=None,
-        shadow_banned=False,
-    ):
+        user_id: str,
+        password_hash: Optional[str] = None,
+        was_guest: bool = False,
+        make_guest: bool = False,
+        appservice_id: Optional[str] = None,
+        create_profile_with_displayname: Optional[str] = None,
+        admin: bool = False,
+        user_type: Optional[str] = None,
+        address: Optional[str] = None,
+        shadow_banned: bool = False,
+    ) -> None:
         """Register user in the datastore.
 
         Args:
-            user_id (str): The desired user ID to register.
-            password_hash (str|None): Optional. The password hash for this user.
-            was_guest (bool): Optional. Whether this is a guest account being
+            user_id: The desired user ID to register.
+            password_hash: Optional. The password hash for this user.
+            was_guest: Optional. Whether this is a guest account being
                 upgraded to a non-guest account.
-            make_guest (boolean): True if the the new user should be guest,
+            make_guest: True if the the new user should be guest,
                 false to add a regular user account.
-            appservice_id (str|None): The ID of the appservice registering the user.
-            create_profile_with_displayname (unicode|None): Optionally create a
+            appservice_id: The ID of the appservice registering the user.
+            create_profile_with_displayname: Optionally create a
                 profile for the user, setting their displayname to the given value
-            admin (boolean): is an admin user?
-            user_type (str|None): type of user. One of the values from
+            admin: is an admin user?
+            user_type: type of user. One of the values from
                 api.constants.UserTypes, or None for a normal user.
-            address (str|None): the IP address used to perform the registration.
-            shadow_banned (bool): Whether to shadow-ban the user
-
-        Returns:
-            Awaitable
+            address: the IP address used to perform the registration.
+            shadow_banned: Whether to shadow-ban the user
         """
         if self.hs.config.worker_app:
-            return self._register_client(
+            await self._register_client(
                 user_id=user_id,
                 password_hash=password_hash,
                 was_guest=was_guest,
@@ -601,7 +612,7 @@ class RegistrationHandler(BaseHandler):
                 shadow_banned=shadow_banned,
             )
         else:
-            return self.store.register_user(
+            await self.store.register_user(
                 user_id=user_id,
                 password_hash=password_hash,
                 was_guest=was_guest,
@@ -614,22 +625,24 @@ class RegistrationHandler(BaseHandler):
             )
 
     async def register_device(
-        self, user_id, device_id, initial_display_name, is_guest=False
-    ):
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        initial_display_name: Optional[str],
+        is_guest: bool = False,
+    ) -> Tuple[str, str]:
         """Register a device for a user and generate an access token.
 
         The access token will be limited by the homeserver's session_lifetime config.
 
         Args:
-            user_id (str): full canonical @user:id
-            device_id (str|None): The device ID to check, or None to generate
-                a new one.
-            initial_display_name (str|None): An optional display name for the
-                device.
-            is_guest (bool): Whether this is a guest account
+            user_id: full canonical @user:id
+            device_id: The device ID to check, or None to generate a new one.
+            initial_display_name: An optional display name for the device.
+            is_guest: Whether this is a guest account
 
         Returns:
-            tuple[str, str]: Tuple of device ID and access token
+            Tuple of device ID and access token
         """
 
         if self.hs.config.worker_app:
@@ -649,7 +662,7 @@ class RegistrationHandler(BaseHandler):
                 )
             valid_until_ms = self.clock.time_msec() + self.session_lifetime
 
-        device_id = await self.device_handler.check_device_registered(
+        registered_device_id = await self.device_handler.check_device_registered(
             user_id, device_id, initial_display_name
         )
         if is_guest:
@@ -659,20 +672,21 @@ class RegistrationHandler(BaseHandler):
             )
         else:
             access_token = await self._auth_handler.get_access_token_for_user_id(
-                user_id, device_id=device_id, valid_until_ms=valid_until_ms
+                user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms
             )
 
-        return (device_id, access_token)
+        return (registered_device_id, access_token)
 
-    async def post_registration_actions(self, user_id, auth_result, access_token):
+    async def post_registration_actions(
+        self, user_id: str, auth_result: dict, access_token: Optional[str]
+    ) -> None:
         """A user has completed registration
 
         Args:
-            user_id (str): The user ID that consented
-            auth_result (dict): The authenticated credentials of the newly
-                registered user.
-            access_token (str|None): The access token of the newly logged in
-                device, or None if `inhibit_login` enabled.
+            user_id: The user ID that consented
+            auth_result: The authenticated credentials of the newly registered user.
+            access_token: The access token of the newly logged in device, or
+                None if `inhibit_login` enabled.
         """
         if self.hs.config.worker_app:
             await self._post_registration_client(
@@ -698,19 +712,20 @@ class RegistrationHandler(BaseHandler):
         if auth_result and LoginType.TERMS in auth_result:
             await self._on_user_consented(user_id, self.hs.config.user_consent_version)
 
-    async def _on_user_consented(self, user_id, consent_version):
+    async def _on_user_consented(self, user_id: str, consent_version: str) -> None:
         """A user consented to the terms on registration
 
         Args:
-            user_id (str): The user ID that consented.
-            consent_version (str): version of the policy the user has
-                consented to.
+            user_id: The user ID that consented.
+            consent_version: version of the policy the user has consented to.
         """
         logger.info("%s has consented to the privacy policy", user_id)
         await self.store.user_set_consent_version(user_id, consent_version)
         await self.post_consent_actions(user_id)
 
-    async def _register_email_threepid(self, user_id, threepid, token):
+    async def _register_email_threepid(
+        self, user_id: str, threepid: dict, token: Optional[str]
+    ) -> None:
         """Add an email address as a 3pid identifier
 
         Also adds an email pusher for the email address, if configured in the
@@ -719,10 +734,9 @@ class RegistrationHandler(BaseHandler):
         Must be called on master.
 
         Args:
-            user_id (str): id of user
-            threepid (object): m.login.email.identity auth response
-            token (str|None): access_token for the user, or None if not logged
-                in.
+            user_id: id of user
+            threepid: m.login.email.identity auth response
+            token: access_token for the user, or None if not logged in.
         """
         reqd = ("medium", "address", "validated_at")
         if any(x not in threepid for x in reqd):
@@ -748,6 +762,8 @@ class RegistrationHandler(BaseHandler):
             # up when the access token is saved, but that's quite an
             # invasive change I'd rather do separately.
             user_tuple = await self.store.get_user_by_access_token(token)
+            # The token better still exist.
+            assert user_tuple
             token_id = user_tuple.token_id
 
             await self.pusher_pool.add_pusher(
@@ -762,14 +778,14 @@ class RegistrationHandler(BaseHandler):
                 data={},
             )
 
-    async def _register_msisdn_threepid(self, user_id, threepid):
+    async def _register_msisdn_threepid(self, user_id: str, threepid: dict) -> None:
         """Add a phone number as a 3pid identifier
 
         Must be called on master.
 
         Args:
-            user_id (str): id of user
-            threepid (object): m.login.msisdn auth response
+            user_id: id of user
+            threepid: m.login.msisdn auth response
         """
         try:
             assert_params_in_dict(threepid, ["medium", "address", "validated_at"])
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index f4e8cbeac8..37ab42f050 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -39,7 +39,7 @@ from synapse.util.async_helpers import Linearizer
 from synapse.util.iterutils import chunk_seq
 
 if TYPE_CHECKING:
-    import synapse.server
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -56,7 +56,7 @@ class Saml2SessionData:
 
 
 class SamlHandler(BaseHandler):
-    def __init__(self, hs: "synapse.server.HomeServer"):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
         self._saml_idp_entityid = hs.config.saml2_idp_entityid
@@ -330,7 +330,7 @@ class SamlHandler(BaseHandler):
                 localpart=localpart,
                 default_display_name=displayname,
                 bind_emails=emails,
-                user_agent_ips=(user_agent, ip_address),
+                user_agent_ips=[(user_agent, ip_address)],
             )
 
             await self.store.record_user_external_id(