summary refs log tree commit diff
path: root/tests/handlers/test_register.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_register.py')
-rw-r--r--tests/handlers/test_register.py29
1 files changed, 14 insertions, 15 deletions
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index f1dc51d6c9..1b7935cef2 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -175,7 +175,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         room_alias_str = "#room:test"
         self.hs.config.auto_join_rooms = [room_alias_str]
 
-        self.store.is_real_user = Mock(return_value=False)
+        self.store.is_real_user = Mock(return_value=defer.succeed(False))
         user_id = self.get_success(self.handler.register_user(localpart="support"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
@@ -187,8 +187,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         room_alias_str = "#room:test"
         self.hs.config.auto_join_rooms = [room_alias_str]
 
-        self.store.count_real_users = Mock(return_value=1)
-        self.store.is_real_user = Mock(return_value=True)
+        self.store.count_real_users = Mock(return_value=defer.succeed(1))
+        self.store.is_real_user = Mock(return_value=defer.succeed(True))
         user_id = self.get_success(self.handler.register_user(localpart="real"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         directory_handler = self.hs.get_handlers().directory_handler
@@ -202,8 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         room_alias_str = "#room:test"
         self.hs.config.auto_join_rooms = [room_alias_str]
 
-        self.store.count_real_users = Mock(return_value=2)
-        self.store.is_real_user = Mock(return_value=True)
+        self.store.count_real_users = Mock(return_value=defer.succeed(2))
+        self.store.is_real_user = Mock(return_value=defer.succeed(True))
         user_id = self.get_success(self.handler.register_user(localpart="real"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
@@ -256,8 +256,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             self.handler.register_user(localpart=invalid_user_id), SynapseError
         )
 
-    @defer.inlineCallbacks
-    def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
+    async def get_or_create_user(
+        self, requester, localpart, displayname, password_hash=None
+    ):
         """Creates a new user if the user does not exist,
         else revokes all previous access tokens and generates a new one.
 
@@ -272,11 +273,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         """
         if localpart is None:
             raise SynapseError(400, "Request must include user id")
-        yield self.hs.get_auth().check_auth_blocking()
+        await self.hs.get_auth().check_auth_blocking()
         need_register = True
 
         try:
-            yield self.handler.check_username(localpart)
+            await self.handler.check_username(localpart)
         except SynapseError as e:
             if e.errcode == Codes.USER_IN_USE:
                 need_register = False
@@ -288,23 +289,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         token = self.macaroon_generator.generate_access_token(user_id)
 
         if need_register:
-            yield self.handler.register_with_store(
+            await self.handler.register_with_store(
                 user_id=user_id,
                 password_hash=password_hash,
                 create_profile_with_displayname=user.localpart,
             )
         else:
-            yield defer.ensureDeferred(
-                self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
-            )
+            await self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
 
-        yield self.store.add_access_token_to_user(
+        await self.store.add_access_token_to_user(
             user_id=user_id, token=token, device_id=None, valid_until_ms=None
         )
 
         if displayname is not None:
             # logger.info("setting user display name: %s -> %s", user_id, displayname)
-            yield self.hs.get_profile_handler().set_displayname(
+            await self.hs.get_profile_handler().set_displayname(
                 user, requester, displayname, by_admin=True
             )