summary refs log tree commit diff
path: root/synapse/handlers/register.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/register.py')
-rw-r--r--synapse/handlers/register.py233
1 files changed, 113 insertions, 120 deletions
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 06bd03b77c..51979ea43e 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -16,18 +16,9 @@
 """Contains functions for registering clients."""
 import logging
 
-from twisted.internet import defer
-
 from synapse import types
 from synapse.api.constants import MAX_USERID_LENGTH, LoginType
-from synapse.api.errors import (
-    AuthError,
-    Codes,
-    ConsentNotGivenError,
-    LimitExceededError,
-    RegistrationError,
-    SynapseError,
-)
+from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
 from synapse.config.server import is_threepid_reserved
 from synapse.http.servlet import assert_params_in_dict
 from synapse.replication.http.login import RegisterDeviceReplicationServlet
@@ -82,8 +73,9 @@ class RegistrationHandler(BaseHandler):
 
         self.session_lifetime = hs.config.session_lifetime
 
-    @defer.inlineCallbacks
-    def check_username(self, localpart, guest_access_token=None, assigned_user_id=None):
+    async def check_username(
+        self, localpart, guest_access_token=None, assigned_user_id=None
+    ):
         if types.contains_invalid_mxid_characters(localpart):
             raise SynapseError(
                 400,
@@ -120,13 +112,13 @@ class RegistrationHandler(BaseHandler):
                 Codes.INVALID_USERNAME,
             )
 
-        users = yield self.store.get_users_by_id_case_insensitive(user_id)
+        users = await self.store.get_users_by_id_case_insensitive(user_id)
         if users:
             if not guest_access_token:
                 raise SynapseError(
                     400, "User ID already taken.", errcode=Codes.USER_IN_USE
                 )
-            user_data = yield self.auth.get_user_by_access_token(guest_access_token)
+            user_data = await self.auth.get_user_by_access_token(guest_access_token)
             if not user_data["is_guest"] or user_data["user"].localpart != localpart:
                 raise AuthError(
                     403,
@@ -135,11 +127,19 @@ class RegistrationHandler(BaseHandler):
                     errcode=Codes.FORBIDDEN,
                 )
 
-    @defer.inlineCallbacks
-    def register_user(
+        if guest_access_token is None:
+            try:
+                int(localpart)
+                raise SynapseError(
+                    400, "Numeric user IDs are reserved for guest users."
+                )
+            except ValueError:
+                pass
+
+    async def register_user(
         self,
         localpart=None,
-        password=None,
+        password_hash=None,
         guest_access_token=None,
         make_guest=False,
         admin=False,
@@ -148,13 +148,14 @@ class RegistrationHandler(BaseHandler):
         default_display_name=None,
         address=None,
         bind_emails=[],
+        by_admin=False,
     ):
         """Registers a new client on the server.
 
         Args:
-            localpart : The local part of the user ID to register. If None,
+            localpart: The local part of the user ID to register. If None,
               one will be generated.
-            password (unicode) : The password to assign to this user so they can
+            password_hash (str|None): The hashed password to assign to this user so they can
               login again. This can be None which means they cannot login again
               via a password (e.g. the user is an application service user).
             user_type (str|None): type of user. One of the values from
@@ -163,31 +164,24 @@ class RegistrationHandler(BaseHandler):
               will be set to this. Defaults to 'localpart'.
             address (str|None): the IP address used to perform the registration.
             bind_emails (List[str]): list of emails to bind to this account.
+            by_admin (bool): True if this registration is being made via the
+              admin api, otherwise False.
         Returns:
-            Deferred[str]: user_id
+            str: user_id
         Raises:
-            RegistrationError if there was a problem registering.
+            SynapseError if there was a problem registering.
         """
+        self.check_registration_ratelimit(address)
 
-        yield self.auth.check_auth_blocking(threepid=threepid)
-        password_hash = None
-        if password:
-            password_hash = yield self._auth_handler.hash(password)
+        # do not check_auth_blocking if the call is coming through the Admin API
+        if not by_admin:
+            await self.auth.check_auth_blocking(threepid=threepid)
 
-        if localpart:
-            yield self.check_username(localpart, guest_access_token=guest_access_token)
+        if localpart is not None:
+            await self.check_username(localpart, guest_access_token=guest_access_token)
 
             was_guest = guest_access_token is not None
 
-            if not was_guest:
-                try:
-                    int(localpart)
-                    raise RegistrationError(
-                        400, "Numeric user IDs are reserved for guest users."
-                    )
-                except ValueError:
-                    pass
-
             user = UserID(localpart, self.hs.hostname)
             user_id = user.to_string()
 
@@ -198,7 +192,7 @@ class RegistrationHandler(BaseHandler):
             elif default_display_name is None:
                 default_display_name = localpart
 
-            yield self.register_with_store(
+            await self.register_with_store(
                 user_id=user_id,
                 password_hash=password_hash,
                 was_guest=was_guest,
@@ -210,38 +204,51 @@ class RegistrationHandler(BaseHandler):
             )
 
             if self.hs.config.user_directory_search_all_users:
-                profile = yield self.store.get_profileinfo(localpart)
-                yield self.user_directory_handler.handle_local_profile_change(
+                profile = await self.store.get_profileinfo(localpart)
+                await self.user_directory_handler.handle_local_profile_change(
                     user_id, profile
                 )
 
         else:
             # autogen a sequential user ID
-            attempts = 0
+            fail_count = 0
             user = None
             while not user:
-                localpart = yield self._generate_user_id(attempts > 0)
+                # Fail after being unable to find a suitable ID a few times
+                if fail_count > 10:
+                    raise SynapseError(500, "Unable to find a suitable guest user ID")
+
+                localpart = await self._generate_user_id()
                 user = UserID(localpart, self.hs.hostname)
                 user_id = user.to_string()
-                yield self.check_user_id_not_appservice_exclusive(user_id)
+                self.check_user_id_not_appservice_exclusive(user_id)
                 if default_display_name is None:
                     default_display_name = localpart
                 try:
-                    yield self.register_with_store(
+                    await self.register_with_store(
                         user_id=user_id,
                         password_hash=password_hash,
                         make_guest=make_guest,
                         create_profile_with_displayname=default_display_name,
                         address=address,
                     )
+
+                    # Successfully registered
+                    break
                 except SynapseError:
                     # if user id is taken, just generate another
                     user = None
                     user_id = None
-                    attempts += 1
+                    fail_count += 1
 
         if not self.hs.config.user_consent_at_registration:
-            yield self._auto_join_rooms(user_id)
+            if not self.hs.config.auto_join_rooms_for_guests and make_guest:
+                logger.info(
+                    "Skipping auto-join for %s because auto-join for guests is disabled",
+                    user_id,
+                )
+            else:
+                await self._auto_join_rooms(user_id)
         else:
             logger.info(
                 "Skipping auto-join for %s because consent is required at registration",
@@ -259,12 +266,11 @@ class RegistrationHandler(BaseHandler):
             }
 
             # Bind email to new account
-            yield self._register_email_threepid(user_id, threepid_dict, None, False)
+            await self._register_email_threepid(user_id, threepid_dict, None)
 
         return user_id
 
-    @defer.inlineCallbacks
-    def _auto_join_rooms(self, user_id):
+    async def _auto_join_rooms(self, user_id):
         """Automatically joins users to auto join rooms - creating the room in the first place
         if the user is the first to be created.
 
@@ -278,9 +284,9 @@ class RegistrationHandler(BaseHandler):
         # that an auto-generated support or bot user is not a real user and will never be
         # the user to create the room
         should_auto_create_rooms = False
-        is_real_user = yield self.store.is_real_user(user_id)
+        is_real_user = await self.store.is_real_user(user_id)
         if self.hs.config.autocreate_auto_join_rooms and is_real_user:
-            count = yield self.store.count_real_users()
+            count = await self.store.count_real_users()
             should_auto_create_rooms = count == 1
         for r in self.hs.config.auto_join_rooms:
             logger.info("Auto-joining %s to %s", user_id, r)
@@ -299,7 +305,7 @@ class RegistrationHandler(BaseHandler):
 
                         # getting the RoomCreationHandler during init gives a dependency
                         # loop
-                        yield self.hs.get_room_creation_handler().create_room(
+                        await self.hs.get_room_creation_handler().create_room(
                             fake_requester,
                             config={
                                 "preset": "public_chat",
@@ -308,7 +314,7 @@ class RegistrationHandler(BaseHandler):
                             ratelimit=False,
                         )
                 else:
-                    yield self._join_user_to_room(fake_requester, r)
+                    await self._join_user_to_room(fake_requester, r)
             except ConsentNotGivenError as e:
                 # Technically not necessary to pull out this error though
                 # moving away from bare excepts is a good thing to do.
@@ -316,18 +322,16 @@ class RegistrationHandler(BaseHandler):
             except Exception as e:
                 logger.error("Failed to join new user to %r: %r", r, e)
 
-    @defer.inlineCallbacks
-    def post_consent_actions(self, user_id):
+    async def post_consent_actions(self, user_id):
         """A series of registration actions that can only be carried out once consent
         has been granted
 
         Args:
             user_id (str): The user to join
         """
-        yield self._auto_join_rooms(user_id)
+        await self._auto_join_rooms(user_id)
 
-    @defer.inlineCallbacks
-    def appservice_register(self, user_localpart, as_token):
+    async def appservice_register(self, user_localpart, as_token):
         user = UserID(user_localpart, self.hs.hostname)
         user_id = user.to_string()
         service = self.store.get_app_service_by_token(as_token)
@@ -342,11 +346,9 @@ class RegistrationHandler(BaseHandler):
 
         service_id = service.id if service.is_exclusive_user(user_id) else None
 
-        yield self.check_user_id_not_appservice_exclusive(
-            user_id, allowed_appservice=service
-        )
+        self.check_user_id_not_appservice_exclusive(user_id, allowed_appservice=service)
 
-        yield self.register_with_store(
+        await self.register_with_store(
             user_id=user_id,
             password_hash="",
             appservice_id=service_id,
@@ -378,28 +380,26 @@ class RegistrationHandler(BaseHandler):
                     errcode=Codes.EXCLUSIVE,
                 )
 
-    @defer.inlineCallbacks
-    def _generate_user_id(self, reseed=False):
-        if reseed or self._next_generated_user_id is None:
-            with (yield self._generate_user_id_linearizer.queue(())):
-                if reseed or self._next_generated_user_id is None:
+    async def _generate_user_id(self):
+        if self._next_generated_user_id is None:
+            with await self._generate_user_id_linearizer.queue(()):
+                if self._next_generated_user_id is None:
                     self._next_generated_user_id = (
-                        yield self.store.find_next_generated_user_id_localpart()
+                        await self.store.find_next_generated_user_id_localpart()
                     )
 
         id = self._next_generated_user_id
         self._next_generated_user_id += 1
         return str(id)
 
-    @defer.inlineCallbacks
-    def _join_user_to_room(self, requester, room_identifier):
+    async def _join_user_to_room(self, requester, room_identifier):
         room_member_handler = self.hs.get_room_member_handler()
         if RoomID.is_valid(room_identifier):
             room_id = room_identifier
         elif RoomAlias.is_valid(room_identifier):
             room_alias = RoomAlias.from_string(room_identifier)
-            room_id, remote_room_hosts = (
-                yield room_member_handler.lookup_room_alias(room_alias)
+            room_id, remote_room_hosts = await room_member_handler.lookup_room_alias(
+                room_alias
             )
             room_id = room_id.to_string()
         else:
@@ -407,7 +407,7 @@ class RegistrationHandler(BaseHandler):
                 400, "%s was not legal room ID or room alias" % (room_identifier,)
             )
 
-        yield room_member_handler.update_membership(
+        await room_member_handler.update_membership(
             requester=requester,
             target=requester.user,
             room_id=room_id,
@@ -416,6 +416,22 @@ class RegistrationHandler(BaseHandler):
             ratelimit=False,
         )
 
+    def check_registration_ratelimit(self, address):
+        """A simple helper method to check whether the registration rate limit has been hit
+        for a given IP address
+
+        Args:
+            address (str|None): the IP address used to perform the registration. If this is
+                None, no ratelimiting will be performed.
+
+        Raises:
+            LimitExceededError: If the rate limit has been exceeded.
+        """
+        if not address:
+            return
+
+        self.ratelimiter.ratelimit(address)
+
     def register_with_store(
         self,
         user_id,
@@ -448,22 +464,6 @@ class RegistrationHandler(BaseHandler):
         Returns:
             Deferred
         """
-        # Don't rate limit for app services
-        if appservice_id is None and address is not None:
-            time_now = self.clock.time()
-
-            allowed, time_allowed = self.ratelimiter.can_do_action(
-                address,
-                time_now_s=time_now,
-                rate_hz=self.hs.config.rc_registration.per_second,
-                burst_count=self.hs.config.rc_registration.burst_count,
-            )
-
-            if not allowed:
-                raise LimitExceededError(
-                    retry_after_ms=int(1000 * (time_allowed - time_now))
-                )
-
         if self.hs.config.worker_app:
             return self._register_client(
                 user_id=user_id,
@@ -488,8 +488,9 @@ class RegistrationHandler(BaseHandler):
                 user_type=user_type,
             )
 
-    @defer.inlineCallbacks
-    def register_device(self, user_id, device_id, initial_display_name, is_guest=False):
+    async def register_device(
+        self, user_id, device_id, initial_display_name, is_guest=False
+    ):
         """Register a device for a user and generate an access token.
 
         The access token will be limited by the homeserver's session_lifetime config.
@@ -503,11 +504,11 @@ class RegistrationHandler(BaseHandler):
             is_guest (bool): Whether this is a guest account
 
         Returns:
-            defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
+            tuple[str, str]: Tuple of device ID and access token
         """
 
         if self.hs.config.worker_app:
-            r = yield self._register_device_client(
+            r = await self._register_device_client(
                 user_id=user_id,
                 device_id=device_id,
                 initial_display_name=initial_display_name,
@@ -523,7 +524,7 @@ class RegistrationHandler(BaseHandler):
                 )
             valid_until_ms = self.clock.time_msec() + self.session_lifetime
 
-        device_id = yield self.device_handler.check_device_registered(
+        device_id = await self.device_handler.check_device_registered(
             user_id, device_id, initial_display_name
         )
         if is_guest:
@@ -532,14 +533,13 @@ class RegistrationHandler(BaseHandler):
                 user_id, ["guest = true"]
             )
         else:
-            access_token = yield self._auth_handler.get_access_token_for_user_id(
+            access_token = await self._auth_handler.get_access_token_for_user_id(
                 user_id, device_id=device_id, valid_until_ms=valid_until_ms
             )
 
         return (device_id, access_token)
 
-    @defer.inlineCallbacks
-    def post_registration_actions(self, user_id, auth_result, access_token):
+    async def post_registration_actions(self, user_id, auth_result, access_token):
         """A user has completed registration
 
         Args:
@@ -550,7 +550,7 @@ class RegistrationHandler(BaseHandler):
                 device, or None if `inhibit_login` enabled.
         """
         if self.hs.config.worker_app:
-            yield self._post_registration_client(
+            await self._post_registration_client(
                 user_id=user_id, auth_result=auth_result, access_token=access_token
             )
             return
@@ -562,19 +562,18 @@ class RegistrationHandler(BaseHandler):
             if is_threepid_reserved(
                 self.hs.config.mau_limits_reserved_threepids, threepid
             ):
-                yield self.store.upsert_monthly_active_user(user_id)
+                await self.store.upsert_monthly_active_user(user_id)
 
-            yield self._register_email_threepid(user_id, threepid, access_token)
+            await self._register_email_threepid(user_id, threepid, access_token)
 
         if auth_result and LoginType.MSISDN in auth_result:
             threepid = auth_result[LoginType.MSISDN]
-            yield self._register_msisdn_threepid(user_id, threepid)
+            await self._register_msisdn_threepid(user_id, threepid)
 
         if auth_result and LoginType.TERMS in auth_result:
-            yield self._on_user_consented(user_id, self.hs.config.user_consent_version)
+            await self._on_user_consented(user_id, self.hs.config.user_consent_version)
 
-    @defer.inlineCallbacks
-    def _on_user_consented(self, user_id, consent_version):
+    async def _on_user_consented(self, user_id, consent_version):
         """A user consented to the terms on registration
 
         Args:
@@ -583,11 +582,10 @@ class RegistrationHandler(BaseHandler):
                 consented to.
         """
         logger.info("%s has consented to the privacy policy", user_id)
-        yield self.store.user_set_consent_version(user_id, consent_version)
-        yield self.post_consent_actions(user_id)
+        await self.store.user_set_consent_version(user_id, consent_version)
+        await self.post_consent_actions(user_id)
 
-    @defer.inlineCallbacks
-    def _register_email_threepid(self, user_id, threepid, token):
+    async def _register_email_threepid(self, user_id, threepid, token):
         """Add an email address as a 3pid identifier
 
         Also adds an email pusher for the email address, if configured in the
@@ -600,8 +598,6 @@ class RegistrationHandler(BaseHandler):
             threepid (object): m.login.email.identity auth response
             token (str|None): access_token for the user, or None if not logged
                 in.
-        Returns:
-            defer.Deferred:
         """
         reqd = ("medium", "address", "validated_at")
         if any(x not in threepid for x in reqd):
@@ -609,14 +605,14 @@ class RegistrationHandler(BaseHandler):
             logger.info("Can't add incomplete 3pid")
             return
 
-        yield self._auth_handler.add_threepid(
-            user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
+        await self._auth_handler.add_threepid(
+            user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
         )
 
         # And we add an email pusher for them by default, but only
         # if email notifications are enabled (so people don't start
         # getting mail spam where they weren't before if email
-        # notifs are set up on a home server)
+        # notifs are set up on a homeserver)
         if (
             self.hs.config.email_enable_notifs
             and self.hs.config.email_notif_for_new_users
@@ -626,10 +622,10 @@ class RegistrationHandler(BaseHandler):
             # It would really make more sense for this to be passed
             # up when the access token is saved, but that's quite an
             # invasive change I'd rather do separately.
-            user_tuple = yield self.store.get_user_by_access_token(token)
+            user_tuple = await self.store.get_user_by_access_token(token)
             token_id = user_tuple["token_id"]
 
-            yield self.pusher_pool.add_pusher(
+            await self.pusher_pool.add_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="email",
@@ -641,8 +637,7 @@ class RegistrationHandler(BaseHandler):
                 data={},
             )
 
-    @defer.inlineCallbacks
-    def _register_msisdn_threepid(self, user_id, threepid):
+    async def _register_msisdn_threepid(self, user_id, threepid):
         """Add a phone number as a 3pid identifier
 
         Must be called on master.
@@ -650,8 +645,6 @@ class RegistrationHandler(BaseHandler):
         Args:
             user_id (str): id of user
             threepid (object): m.login.msisdn auth response
-        Returns:
-            defer.Deferred:
         """
         try:
             assert_params_in_dict(threepid, ["medium", "address", "validated_at"])
@@ -662,6 +655,6 @@ class RegistrationHandler(BaseHandler):
                 return None
             raise
 
-        yield self._auth_handler.add_threepid(
-            user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
+        await self._auth_handler.add_threepid(
+            user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
         )