summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-06-08 11:15:02 -0400
committerGitHub <noreply@github.com>2020-06-08 11:15:02 -0400
commit3c45a7809036126a44636f8aaffd42bbc633b9ac (patch)
tree3835c5cc9f3fa2a39c55dbb88b0fe28cf134e4c9
parentAccept device information at the login fallback endpoint. (#7629) (diff)
downloadsynapse-3c45a7809036126a44636f8aaffd42bbc633b9ac.tar.xz
Convert the registration handler to async/await. (#7649)
-rw-r--r--changelog.d/7649.misc1
-rw-r--r--synapse/handlers/register.py107
-rw-r--r--synapse/module_api/__init__.py8
3 files changed, 48 insertions, 68 deletions
diff --git a/changelog.d/7649.misc b/changelog.d/7649.misc
new file mode 100644
index 0000000000..8a26c8b3b7
--- /dev/null
+++ b/changelog.d/7649.misc
@@ -0,0 +1 @@
+Convert registration handler to async/await.
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index af812dbda9..51979ea43e 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -16,8 +16,6 @@
 """Contains functions for registering clients."""
 import logging
 
-from twisted.internet import defer
-
 from synapse import types
 from synapse.api.constants import MAX_USERID_LENGTH, LoginType
 from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
@@ -75,8 +73,9 @@ class RegistrationHandler(BaseHandler):
 
         self.session_lifetime = hs.config.session_lifetime
 
-    @defer.inlineCallbacks
-    def check_username(self, localpart, guest_access_token=None, assigned_user_id=None):
+    async def check_username(
+        self, localpart, guest_access_token=None, assigned_user_id=None
+    ):
         if types.contains_invalid_mxid_characters(localpart):
             raise SynapseError(
                 400,
@@ -113,13 +112,13 @@ class RegistrationHandler(BaseHandler):
                 Codes.INVALID_USERNAME,
             )
 
-        users = yield self.store.get_users_by_id_case_insensitive(user_id)
+        users = await self.store.get_users_by_id_case_insensitive(user_id)
         if users:
             if not guest_access_token:
                 raise SynapseError(
                     400, "User ID already taken.", errcode=Codes.USER_IN_USE
                 )
-            user_data = yield self.auth.get_user_by_access_token(guest_access_token)
+            user_data = await self.auth.get_user_by_access_token(guest_access_token)
             if not user_data["is_guest"] or user_data["user"].localpart != localpart:
                 raise AuthError(
                     403,
@@ -137,8 +136,7 @@ class RegistrationHandler(BaseHandler):
             except ValueError:
                 pass
 
-    @defer.inlineCallbacks
-    def register_user(
+    async def register_user(
         self,
         localpart=None,
         password_hash=None,
@@ -169,18 +167,18 @@ class RegistrationHandler(BaseHandler):
             by_admin (bool): True if this registration is being made via the
               admin api, otherwise False.
         Returns:
-            Deferred[str]: user_id
+            str: user_id
         Raises:
             SynapseError if there was a problem registering.
         """
-        yield self.check_registration_ratelimit(address)
+        self.check_registration_ratelimit(address)
 
         # do not check_auth_blocking if the call is coming through the Admin API
         if not by_admin:
-            yield self.auth.check_auth_blocking(threepid=threepid)
+            await self.auth.check_auth_blocking(threepid=threepid)
 
         if localpart is not None:
-            yield self.check_username(localpart, guest_access_token=guest_access_token)
+            await self.check_username(localpart, guest_access_token=guest_access_token)
 
             was_guest = guest_access_token is not None
 
@@ -194,7 +192,7 @@ class RegistrationHandler(BaseHandler):
             elif default_display_name is None:
                 default_display_name = localpart
 
-            yield self.register_with_store(
+            await self.register_with_store(
                 user_id=user_id,
                 password_hash=password_hash,
                 was_guest=was_guest,
@@ -206,11 +204,9 @@ class RegistrationHandler(BaseHandler):
             )
 
             if self.hs.config.user_directory_search_all_users:
-                profile = yield self.store.get_profileinfo(localpart)
-                yield defer.ensureDeferred(
-                    self.user_directory_handler.handle_local_profile_change(
-                        user_id, profile
-                    )
+                profile = await self.store.get_profileinfo(localpart)
+                await self.user_directory_handler.handle_local_profile_change(
+                    user_id, profile
                 )
 
         else:
@@ -222,14 +218,14 @@ class RegistrationHandler(BaseHandler):
                 if fail_count > 10:
                     raise SynapseError(500, "Unable to find a suitable guest user ID")
 
-                localpart = yield self._generate_user_id()
+                localpart = await self._generate_user_id()
                 user = UserID(localpart, self.hs.hostname)
                 user_id = user.to_string()
-                yield self.check_user_id_not_appservice_exclusive(user_id)
+                self.check_user_id_not_appservice_exclusive(user_id)
                 if default_display_name is None:
                     default_display_name = localpart
                 try:
-                    yield self.register_with_store(
+                    await self.register_with_store(
                         user_id=user_id,
                         password_hash=password_hash,
                         make_guest=make_guest,
@@ -252,7 +248,7 @@ class RegistrationHandler(BaseHandler):
                     user_id,
                 )
             else:
-                yield defer.ensureDeferred(self._auto_join_rooms(user_id))
+                await self._auto_join_rooms(user_id)
         else:
             logger.info(
                 "Skipping auto-join for %s because consent is required at registration",
@@ -270,7 +266,7 @@ class RegistrationHandler(BaseHandler):
             }
 
             # Bind email to new account
-            yield self._register_email_threepid(user_id, threepid_dict, None)
+            await self._register_email_threepid(user_id, threepid_dict, None)
 
         return user_id
 
@@ -335,8 +331,7 @@ class RegistrationHandler(BaseHandler):
         """
         await self._auto_join_rooms(user_id)
 
-    @defer.inlineCallbacks
-    def appservice_register(self, user_localpart, as_token):
+    async def appservice_register(self, user_localpart, as_token):
         user = UserID(user_localpart, self.hs.hostname)
         user_id = user.to_string()
         service = self.store.get_app_service_by_token(as_token)
@@ -351,11 +346,9 @@ class RegistrationHandler(BaseHandler):
 
         service_id = service.id if service.is_exclusive_user(user_id) else None
 
-        yield self.check_user_id_not_appservice_exclusive(
-            user_id, allowed_appservice=service
-        )
+        self.check_user_id_not_appservice_exclusive(user_id, allowed_appservice=service)
 
-        yield self.register_with_store(
+        await self.register_with_store(
             user_id=user_id,
             password_hash="",
             appservice_id=service_id,
@@ -387,13 +380,12 @@ class RegistrationHandler(BaseHandler):
                     errcode=Codes.EXCLUSIVE,
                 )
 
-    @defer.inlineCallbacks
-    def _generate_user_id(self):
+    async def _generate_user_id(self):
         if self._next_generated_user_id is None:
-            with (yield self._generate_user_id_linearizer.queue(())):
+            with await self._generate_user_id_linearizer.queue(()):
                 if self._next_generated_user_id is None:
                     self._next_generated_user_id = (
-                        yield self.store.find_next_generated_user_id_localpart()
+                        await self.store.find_next_generated_user_id_localpart()
                     )
 
         id = self._next_generated_user_id
@@ -496,8 +488,9 @@ class RegistrationHandler(BaseHandler):
                 user_type=user_type,
             )
 
-    @defer.inlineCallbacks
-    def register_device(self, user_id, device_id, initial_display_name, is_guest=False):
+    async def register_device(
+        self, user_id, device_id, initial_display_name, is_guest=False
+    ):
         """Register a device for a user and generate an access token.
 
         The access token will be limited by the homeserver's session_lifetime config.
@@ -511,11 +504,11 @@ class RegistrationHandler(BaseHandler):
             is_guest (bool): Whether this is a guest account
 
         Returns:
-            defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
+            tuple[str, str]: Tuple of device ID and access token
         """
 
         if self.hs.config.worker_app:
-            r = yield self._register_device_client(
+            r = await self._register_device_client(
                 user_id=user_id,
                 device_id=device_id,
                 initial_display_name=initial_display_name,
@@ -531,7 +524,7 @@ class RegistrationHandler(BaseHandler):
                 )
             valid_until_ms = self.clock.time_msec() + self.session_lifetime
 
-        device_id = yield self.device_handler.check_device_registered(
+        device_id = await self.device_handler.check_device_registered(
             user_id, device_id, initial_display_name
         )
         if is_guest:
@@ -540,10 +533,8 @@ class RegistrationHandler(BaseHandler):
                 user_id, ["guest = true"]
             )
         else:
-            access_token = yield defer.ensureDeferred(
-                self._auth_handler.get_access_token_for_user_id(
-                    user_id, device_id=device_id, valid_until_ms=valid_until_ms
-                )
+            access_token = await self._auth_handler.get_access_token_for_user_id(
+                user_id, device_id=device_id, valid_until_ms=valid_until_ms
             )
 
         return (device_id, access_token)
@@ -594,8 +585,7 @@ class RegistrationHandler(BaseHandler):
         await self.store.user_set_consent_version(user_id, consent_version)
         await self.post_consent_actions(user_id)
 
-    @defer.inlineCallbacks
-    def _register_email_threepid(self, user_id, threepid, token):
+    async def _register_email_threepid(self, user_id, threepid, token):
         """Add an email address as a 3pid identifier
 
         Also adds an email pusher for the email address, if configured in the
@@ -608,8 +598,6 @@ class RegistrationHandler(BaseHandler):
             threepid (object): m.login.email.identity auth response
             token (str|None): access_token for the user, or None if not logged
                 in.
-        Returns:
-            defer.Deferred:
         """
         reqd = ("medium", "address", "validated_at")
         if any(x not in threepid for x in reqd):
@@ -617,13 +605,8 @@ class RegistrationHandler(BaseHandler):
             logger.info("Can't add incomplete 3pid")
             return
 
-        yield defer.ensureDeferred(
-            self._auth_handler.add_threepid(
-                user_id,
-                threepid["medium"],
-                threepid["address"],
-                threepid["validated_at"],
-            )
+        await self._auth_handler.add_threepid(
+            user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
         )
 
         # And we add an email pusher for them by default, but only
@@ -639,10 +622,10 @@ class RegistrationHandler(BaseHandler):
             # It would really make more sense for this to be passed
             # up when the access token is saved, but that's quite an
             # invasive change I'd rather do separately.
-            user_tuple = yield self.store.get_user_by_access_token(token)
+            user_tuple = await self.store.get_user_by_access_token(token)
             token_id = user_tuple["token_id"]
 
-            yield self.pusher_pool.add_pusher(
+            await self.pusher_pool.add_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="email",
@@ -654,8 +637,7 @@ class RegistrationHandler(BaseHandler):
                 data={},
             )
 
-    @defer.inlineCallbacks
-    def _register_msisdn_threepid(self, user_id, threepid):
+    async def _register_msisdn_threepid(self, user_id, threepid):
         """Add a phone number as a 3pid identifier
 
         Must be called on master.
@@ -663,8 +645,6 @@ class RegistrationHandler(BaseHandler):
         Args:
             user_id (str): id of user
             threepid (object): m.login.msisdn auth response
-        Returns:
-            defer.Deferred:
         """
         try:
             assert_params_in_dict(threepid, ["medium", "address", "validated_at"])
@@ -675,11 +655,6 @@ class RegistrationHandler(BaseHandler):
                 return None
             raise
 
-        yield defer.ensureDeferred(
-            self._auth_handler.add_threepid(
-                user_id,
-                threepid["medium"],
-                threepid["address"],
-                threepid["validated_at"],
-            )
+        await self._auth_handler.add_threepid(
+            user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
         )
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index d678c0eb9b..ecdf1ad69f 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -128,8 +128,12 @@ class ModuleApi(object):
         Returns:
             Deferred[str]: user_id
         """
-        return self._hs.get_registration_handler().register_user(
-            localpart=localpart, default_display_name=displayname, bind_emails=emails
+        return defer.ensureDeferred(
+            self._hs.get_registration_handler().register_user(
+                localpart=localpart,
+                default_display_name=displayname,
+                bind_emails=emails,
+            )
         )
 
     def register_device(self, user_id, device_id=None, initial_display_name=None):