diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 94809cb8be..4740dd0a65 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -21,13 +21,15 @@ from twisted.internet import defer
import synapse
from synapse.api.constants import LoginType
+from synapse.api.errors import Codes
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.module_api import ModuleApi
-from synapse.rest.client import devices, login, logout, register
+from synapse.rest.client import account, devices, login, logout, register
from synapse.types import JsonDict, UserID
from tests import unittest
from tests.server import FakeChannel
+from tests.test_utils import make_awaitable
from tests.unittest import override_config
# (possibly experimental) login flows we expect to appear in the list after the normal
@@ -158,6 +160,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
devices.register_servlets,
logout.register_servlets,
register.register_servlets,
+ account.register_servlets,
]
def setUp(self):
@@ -803,6 +806,77 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# Check that the callback has been called.
m.assert_called_once()
+ # Set some email configuration so the test doesn't fail because of its absence.
+ @override_config({"email": {"notif_from": "noreply@test"}})
+ def test_3pid_allowed(self):
+ """Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse
+ to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind
+ the 3PID. Also checks that the module is passed a boolean indicating whether the
+ user to bind this 3PID to is currently registering.
+ """
+ self._test_3pid_allowed("rin", False)
+ self._test_3pid_allowed("kitay", True)
+
+ def _test_3pid_allowed(self, username: str, registration: bool):
+ """Tests that the "is_3pid_allowed" module callback is called correctly, using
+ either /register or /account URLs depending on the arguments.
+
+ Args:
+ username: The username to use for the test.
+ registration: Whether to test with registration URLs.
+ """
+ self.hs.get_identity_handler().send_threepid_validation = Mock(
+ return_value=make_awaitable(0),
+ )
+
+ m = Mock(return_value=make_awaitable(False))
+ self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
+
+ self.register_user(username, "password")
+ tok = self.login(username, "password")
+
+ if registration:
+ url = "/register/email/requestToken"
+ else:
+ url = "/account/3pid/email/requestToken"
+
+ channel = self.make_request(
+ "POST",
+ url,
+ {
+ "client_secret": "foo",
+ "email": "foo@test.com",
+ "send_attempt": 0,
+ },
+ access_token=tok,
+ )
+ self.assertEqual(channel.code, 403, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"],
+ Codes.THREEPID_DENIED,
+ channel.json_body,
+ )
+
+ m.assert_called_once_with("email", "foo@test.com", registration)
+
+ m = Mock(return_value=make_awaitable(True))
+ self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
+
+ channel = self.make_request(
+ "POST",
+ url,
+ {
+ "client_secret": "foo",
+ "email": "bar@test.com",
+ "send_attempt": 0,
+ },
+ access_token=tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertIn("sid", channel.json_body)
+
+ m.assert_called_once_with("email", "bar@test.com", registration)
+
def _setup_get_username_for_registration(self) -> Mock:
"""Registers a get_username_for_registration callback that appends "-foo" to the
username the client is trying to register.
|