summary refs log tree commit diff
path: root/tests/handlers/test_register.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_register.py')
-rw-r--r--tests/handlers/test_register.py33
1 files changed, 14 insertions, 19 deletions
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 54eeec228e..a04234829f 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from typing import Any, Collection, List, Optional, Tuple
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -38,7 +38,6 @@ from synapse.types import (
 )
 from synapse.util import Clock
 
-from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 from tests.utils import mock_getRawHeaders
 
@@ -203,24 +202,22 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": True})
     def test_get_or_create_user_mau_not_blocked(self) -> None:
-        self.store.count_monthly_users = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
+        self.store.count_monthly_users = AsyncMock(  # type: ignore[assignment]
+            return_value=self.hs.config.server.max_mau_value - 1
         )
         # Ensure does not throw exception
         self.get_success(self.get_or_create_user(self.requester, "c", "User"))
 
     @override_config({"limit_usage_by_mau": True})
     def test_get_or_create_user_mau_blocked(self) -> None:
-        self.store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.lots_of_users)
-        )
+        self.store.get_monthly_active_count = AsyncMock(return_value=self.lots_of_users)
         self.get_failure(
             self.get_or_create_user(self.requester, "b", "display_name"),
             ResourceLimitError,
         )
 
-        self.store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.hs.config.server.max_mau_value)
+        self.store.get_monthly_active_count = AsyncMock(
+            return_value=self.hs.config.server.max_mau_value
         )
         self.get_failure(
             self.get_or_create_user(self.requester, "b", "display_name"),
@@ -229,15 +226,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": True})
     def test_register_mau_blocked(self) -> None:
-        self.store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.lots_of_users)
-        )
+        self.store.get_monthly_active_count = AsyncMock(return_value=self.lots_of_users)
         self.get_failure(
             self.handler.register_user(localpart="local_part"), ResourceLimitError
         )
 
-        self.store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.hs.config.server.max_mau_value)
+        self.store.get_monthly_active_count = AsyncMock(
+            return_value=self.hs.config.server.max_mau_value
         )
         self.get_failure(
             self.handler.register_user(localpart="local_part"), ResourceLimitError
@@ -292,7 +287,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     @override_config({"auto_join_rooms": ["#room:test"]})
     def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None:
         room_alias_str = "#room:test"
-        self.store.is_real_user = Mock(return_value=make_awaitable(False))
+        self.store.is_real_user = AsyncMock(return_value=False)
         user_id = self.get_success(self.handler.register_user(localpart="support"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
@@ -304,8 +299,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
         room_alias_str = "#room:test"
 
-        self.store.count_real_users = Mock(return_value=make_awaitable(1))  # type: ignore[assignment]
-        self.store.is_real_user = Mock(return_value=make_awaitable(True))
+        self.store.count_real_users = AsyncMock(return_value=1)  # type: ignore[assignment]
+        self.store.is_real_user = AsyncMock(return_value=True)
         user_id = self.get_success(self.handler.register_user(localpart="real"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         directory_handler = self.hs.get_directory_handler()
@@ -319,8 +314,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
         self,
     ) -> None:
-        self.store.count_real_users = Mock(return_value=make_awaitable(2))  # type: ignore[assignment]
-        self.store.is_real_user = Mock(return_value=make_awaitable(True))
+        self.store.count_real_users = AsyncMock(return_value=2)  # type: ignore[assignment]
+        self.store.is_real_user = AsyncMock(return_value=True)
         user_id = self.get_success(self.handler.register_user(localpart="real"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)