summary refs log tree commit diff
path: root/tests/handlers/test_password_providers.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_password_providers.py')
-rw-r--r--tests/handlers/test_password_providers.py51
1 files changed, 23 insertions, 28 deletions
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 394006f5f3..4496370c3f 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -16,7 +16,7 @@
 
 from http import HTTPStatus
 from typing import Any, Dict, List, Optional, Type, Union
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -32,7 +32,6 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeChannel
-from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
 # Login flows we expect to appear in the list after the normal ones.
@@ -187,7 +186,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
 
         # check_password must return an awaitable
-        mock_password_provider.check_password.return_value = make_awaitable(True)
+        mock_password_provider.check_password = AsyncMock(return_value=True)
         channel = self._send_password_login("u", "p")
         self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
         self.assertEqual("@u:test", channel.json_body["user_id"])
@@ -209,13 +208,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         """UI Auth should delegate correctly to the password provider"""
 
         # log in twice, to get two devices
-        mock_password_provider.check_password.return_value = make_awaitable(True)
+        mock_password_provider.check_password = AsyncMock(return_value=True)
         tok1 = self.login("u", "p")
         self.login("u", "p", device_id="dev2")
         mock_password_provider.reset_mock()
 
         # have the auth provider deny the request to start with
-        mock_password_provider.check_password.return_value = make_awaitable(False)
+        mock_password_provider.check_password = AsyncMock(return_value=False)
 
         # make the initial request which returns a 401
         session = self._start_delete_device_session(tok1, "dev2")
@@ -229,7 +228,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.reset_mock()
 
         # Finally, check the request goes through when we allow it
-        mock_password_provider.check_password.return_value = make_awaitable(True)
+        mock_password_provider.check_password = AsyncMock(return_value=True)
         channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
         self.assertEqual(channel.code, 200)
         mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
@@ -243,7 +242,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.register_user("localuser", "localpass")
 
         # check_password must return an awaitable
-        mock_password_provider.check_password.return_value = make_awaitable(False)
+        mock_password_provider.check_password = AsyncMock(return_value=False)
         channel = self._send_password_login("u", "p")
         self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
 
@@ -260,7 +259,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.register_user("localuser", "localpass")
 
         # have the auth provider deny the request
-        mock_password_provider.check_password.return_value = make_awaitable(False)
+        mock_password_provider.check_password = AsyncMock(return_value=False)
 
         # log in twice, to get two devices
         tok1 = self.login("localuser", "localpass")
@@ -303,7 +302,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.register_user("localuser", "localpass")
 
         # check_password must return an awaitable
-        mock_password_provider.check_password.return_value = make_awaitable(False)
+        mock_password_provider.check_password = AsyncMock(return_value=False)
         channel = self._send_password_login("localuser", "localpass")
         self.assertEqual(channel.code, 403)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -325,7 +324,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.register_user("localuser", "localpass")
 
         # allow login via the auth provider
-        mock_password_provider.check_password.return_value = make_awaitable(True)
+        mock_password_provider.check_password = AsyncMock(return_value=True)
 
         # log in twice, to get two devices
         tok1 = self.login("localuser", "p")
@@ -342,7 +341,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.check_password.assert_not_called()
 
         # now try deleting with the local password
-        mock_password_provider.check_password.return_value = make_awaitable(False)
+        mock_password_provider.check_password = AsyncMock(return_value=False)
         channel = self._authed_delete_device(
             tok1, "dev2", session, "localuser", "localpass"
         )
@@ -396,9 +395,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
         mock_password_provider.check_auth.assert_not_called()
 
-        mock_password_provider.check_auth.return_value = make_awaitable(
-            ("@user:test", None)
-        )
+        mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None))
         channel = self._send_login("test.login_type", "u", test_field="y")
         self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
         self.assertEqual("@user:test", channel.json_body["user_id"])
@@ -447,9 +444,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.reset_mock()
 
         # right params, but authing as the wrong user
-        mock_password_provider.check_auth.return_value = make_awaitable(
-            ("@user:test", None)
-        )
+        mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None))
         body["auth"]["test_field"] = "foo"
         channel = self._delete_device(tok1, "dev2", body)
         self.assertEqual(channel.code, 403)
@@ -460,8 +455,8 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.reset_mock()
 
         # and finally, succeed
-        mock_password_provider.check_auth.return_value = make_awaitable(
-            ("@localuser:test", None)
+        mock_password_provider.check_auth = AsyncMock(
+            return_value=("@localuser:test", None)
         )
         channel = self._delete_device(tok1, "dev2", body)
         self.assertEqual(channel.code, 200)
@@ -478,10 +473,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.custom_auth_provider_callback_test_body()
 
     def custom_auth_provider_callback_test_body(self) -> None:
-        callback = Mock(return_value=make_awaitable(None))
+        callback = AsyncMock(return_value=None)
 
-        mock_password_provider.check_auth.return_value = make_awaitable(
-            ("@user:test", callback)
+        mock_password_provider.check_auth = AsyncMock(
+            return_value=("@user:test", callback)
         )
         channel = self._send_login("test.login_type", "u", test_field="y")
         self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
@@ -616,8 +611,8 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         login is disabled"""
         # register the user and log in twice via the test login type to get two devices,
         self.register_user("localuser", "localpass")
-        mock_password_provider.check_auth.return_value = make_awaitable(
-            ("@localuser:test", None)
+        mock_password_provider.check_auth = AsyncMock(
+            return_value=("@localuser:test", None)
         )
         channel = self._send_login("test.login_type", "localuser", test_field="")
         self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
@@ -835,11 +830,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             username: The username to use for the test.
             registration: Whether to test with registration URLs.
         """
-        self.hs.get_identity_handler().send_threepid_validation = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(0),
+        self.hs.get_identity_handler().send_threepid_validation = AsyncMock(  # type: ignore[assignment]
+            return_value=0
         )
 
-        m = Mock(return_value=make_awaitable(False))
+        m = AsyncMock(return_value=False)
         self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
 
         self.register_user(username, "password")
@@ -869,7 +864,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
         m.assert_called_once_with("email", "foo@test.com", registration)
 
-        m = Mock(return_value=make_awaitable(True))
+        m = AsyncMock(return_value=True)
         self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
 
         channel = self.make_request(