summary refs log tree commit diff
path: root/tests/handlers/test_password_providers.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_password_providers.py')
-rw-r--r--tests/handlers/test_password_providers.py76
1 files changed, 75 insertions, 1 deletions
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 94809cb8be..4740dd0a65 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -21,13 +21,15 @@ from twisted.internet import defer
 
 import synapse
 from synapse.api.constants import LoginType
+from synapse.api.errors import Codes
 from synapse.handlers.auth import load_legacy_password_auth_providers
 from synapse.module_api import ModuleApi
-from synapse.rest.client import devices, login, logout, register
+from synapse.rest.client import account, devices, login, logout, register
 from synapse.types import JsonDict, UserID
 
 from tests import unittest
 from tests.server import FakeChannel
+from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
 # (possibly experimental) login flows we expect to appear in the list after the normal
@@ -158,6 +160,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         devices.register_servlets,
         logout.register_servlets,
         register.register_servlets,
+        account.register_servlets,
     ]
 
     def setUp(self):
@@ -803,6 +806,77 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         # Check that the callback has been called.
         m.assert_called_once()
 
+    # Set some email configuration so the test doesn't fail because of its absence.
+    @override_config({"email": {"notif_from": "noreply@test"}})
+    def test_3pid_allowed(self):
+        """Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse
+        to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind
+        the 3PID. Also checks that the module is passed a boolean indicating whether the
+        user to bind this 3PID to is currently registering.
+        """
+        self._test_3pid_allowed("rin", False)
+        self._test_3pid_allowed("kitay", True)
+
+    def _test_3pid_allowed(self, username: str, registration: bool):
+        """Tests that the "is_3pid_allowed" module callback is called correctly, using
+        either /register or /account URLs depending on the arguments.
+
+        Args:
+            username: The username to use for the test.
+            registration: Whether to test with registration URLs.
+        """
+        self.hs.get_identity_handler().send_threepid_validation = Mock(
+            return_value=make_awaitable(0),
+        )
+
+        m = Mock(return_value=make_awaitable(False))
+        self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
+
+        self.register_user(username, "password")
+        tok = self.login(username, "password")
+
+        if registration:
+            url = "/register/email/requestToken"
+        else:
+            url = "/account/3pid/email/requestToken"
+
+        channel = self.make_request(
+            "POST",
+            url,
+            {
+                "client_secret": "foo",
+                "email": "foo@test.com",
+                "send_attempt": 0,
+            },
+            access_token=tok,
+        )
+        self.assertEqual(channel.code, 403, channel.result)
+        self.assertEqual(
+            channel.json_body["errcode"],
+            Codes.THREEPID_DENIED,
+            channel.json_body,
+        )
+
+        m.assert_called_once_with("email", "foo@test.com", registration)
+
+        m = Mock(return_value=make_awaitable(True))
+        self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
+
+        channel = self.make_request(
+            "POST",
+            url,
+            {
+                "client_secret": "foo",
+                "email": "bar@test.com",
+                "send_attempt": 0,
+            },
+            access_token=tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertIn("sid", channel.json_body)
+
+        m.assert_called_once_with("email", "bar@test.com", registration)
+
     def _setup_get_username_for_registration(self) -> Mock:
         """Registers a get_username_for_registration callback that appends "-foo" to the
         username the client is trying to register.