summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11854.feature1
-rw-r--r--docs/modules/password_auth_provider_callbacks.md19
-rw-r--r--synapse/handlers/auth.py44
-rw-r--r--synapse/module_api/__init__.py3
-rw-r--r--synapse/rest/client/account.py4
-rw-r--r--synapse/rest/client/register.py8
-rw-r--r--synapse/util/threepids.py13
-rw-r--r--tests/handlers/test_password_providers.py76
8 files changed, 161 insertions, 7 deletions
diff --git a/changelog.d/11854.feature b/changelog.d/11854.feature
new file mode 100644
index 0000000000..975e95bc52
--- /dev/null
+++ b/changelog.d/11854.feature
@@ -0,0 +1 @@
+Add a callback to allow modules to allow or forbid a 3PID (email address, phone number) from being associated to a local account.
diff --git a/docs/modules/password_auth_provider_callbacks.md b/docs/modules/password_auth_provider_callbacks.md
index 3697e3782e..88b59bb09e 100644
--- a/docs/modules/password_auth_provider_callbacks.md
+++ b/docs/modules/password_auth_provider_callbacks.md
@@ -166,6 +166,25 @@ any of the subsequent implementations of this callback. If every callback return
 the username provided by the user is used, if any (otherwise one is automatically
 generated).
 
+## `is_3pid_allowed`
+
+_First introduced in Synapse v1.53.0_
+
+```python
+async def is_3pid_allowed(self, medium: str, address: str, registration: bool) -> bool
+```
+
+Called when attempting to bind a third-party identifier (i.e. an email address or a phone
+number). The module is given the medium of the third-party identifier (which is `email` if
+the identifier is an email address, or `msisdn` if the identifier is a phone number) and
+its address, as well as a boolean indicating whether the attempt to bind is happening as
+part of registering a new user. The module must return a boolean indicating whether the
+identifier can be allowed to be bound to an account on the local homeserver.
+
+If multiple modules implement this callback, they will be considered in order. If a
+callback returns `True`, Synapse falls through to the next one. The value of the first
+callback that does not return `True` will be used. If this happens, Synapse will not call
+any of the subsequent implementations of this callback.
 
 ## Example
 
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index e32c93e234..6959d1aa7e 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -2064,6 +2064,7 @@ GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
     [JsonDict, JsonDict],
     Awaitable[Optional[str]],
 ]
+IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
 
 
 class PasswordAuthProvider:
@@ -2079,6 +2080,7 @@ class PasswordAuthProvider:
         self.get_username_for_registration_callbacks: List[
             GET_USERNAME_FOR_REGISTRATION_CALLBACK
         ] = []
+        self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
 
         # Mapping from login type to login parameters
         self._supported_login_types: Dict[str, Iterable[str]] = {}
@@ -2090,6 +2092,7 @@ class PasswordAuthProvider:
         self,
         check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
         on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
+        is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None,
         auth_checkers: Optional[
             Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
         ] = None,
@@ -2145,6 +2148,9 @@ class PasswordAuthProvider:
                 get_username_for_registration,
             )
 
+        if is_3pid_allowed is not None:
+            self.is_3pid_allowed_callbacks.append(is_3pid_allowed)
+
     def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
         """Get the login types supported by this password provider
 
@@ -2343,3 +2349,41 @@ class PasswordAuthProvider:
                 raise SynapseError(code=500, msg="Internal Server Error")
 
         return None
+
+    async def is_3pid_allowed(
+        self,
+        medium: str,
+        address: str,
+        registration: bool,
+    ) -> bool:
+        """Check if the user can be allowed to bind a 3PID on this homeserver.
+
+        Args:
+            medium: The medium of the 3PID.
+            address: The address of the 3PID.
+            registration: Whether the 3PID is being bound when registering a new user.
+
+        Returns:
+            Whether the 3PID is allowed to be bound on this homeserver
+        """
+        for callback in self.is_3pid_allowed_callbacks:
+            try:
+                res = await callback(medium, address, registration)
+
+                if res is False:
+                    return res
+                elif not isinstance(res, bool):
+                    # mypy complains that this line is unreachable because it assumes the
+                    # data returned by the module fits the expected type. We just want
+                    # to make sure this is the case.
+                    logger.warning(  # type: ignore[unreachable]
+                        "Ignoring non-string value returned by"
+                        " is_3pid_allowed callback %s: %s",
+                        callback,
+                        res,
+                    )
+            except Exception as e:
+                logger.error("Module raised an exception in is_3pid_allowed: %s", e)
+                raise SynapseError(code=500, msg="Internal Server Error")
+
+        return True
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 29fbc73c97..a91a7fa3ce 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -72,6 +72,7 @@ from synapse.handlers.auth import (
     CHECK_3PID_AUTH_CALLBACK,
     CHECK_AUTH_CALLBACK,
     GET_USERNAME_FOR_REGISTRATION_CALLBACK,
+    IS_3PID_ALLOWED_CALLBACK,
     ON_LOGGED_OUT_CALLBACK,
     AuthHandler,
 )
@@ -312,6 +313,7 @@ class ModuleApi:
         auth_checkers: Optional[
             Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
         ] = None,
+        is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None,
         get_username_for_registration: Optional[
             GET_USERNAME_FOR_REGISTRATION_CALLBACK
         ] = None,
@@ -323,6 +325,7 @@ class ModuleApi:
         return self._password_auth_provider.register_password_auth_provider_callbacks(
             check_3pid_auth=check_3pid_auth,
             on_logged_out=on_logged_out,
+            is_3pid_allowed=is_3pid_allowed,
             auth_checkers=auth_checkers,
             get_username_for_registration=get_username_for_registration,
         )
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 6b272658fc..cfa2aee76d 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -385,7 +385,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
         send_attempt = body["send_attempt"]
         next_link = body.get("next_link")  # Optional param
 
-        if not check_3pid_allowed(self.hs, "email", email):
+        if not await check_3pid_allowed(self.hs, "email", email):
             raise SynapseError(
                 403,
                 "Your email domain is not authorized on this server",
@@ -468,7 +468,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
 
         msisdn = phone_number_to_msisdn(country, phone_number)
 
-        if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+        if not await check_3pid_allowed(self.hs, "msisdn", msisdn):
             raise SynapseError(
                 403,
                 "Account phone numbers are not authorized on this server",
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index c283313e8d..c965e2bda2 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -112,7 +112,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
         send_attempt = body["send_attempt"]
         next_link = body.get("next_link")  # Optional param
 
-        if not check_3pid_allowed(self.hs, "email", email):
+        if not await check_3pid_allowed(self.hs, "email", email, registration=True):
             raise SynapseError(
                 403,
                 "Your email domain is not authorized to register on this server",
@@ -192,7 +192,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
 
         msisdn = phone_number_to_msisdn(country, phone_number)
 
-        if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+        if not await check_3pid_allowed(self.hs, "msisdn", msisdn, registration=True):
             raise SynapseError(
                 403,
                 "Phone numbers are not authorized to register on this server",
@@ -616,7 +616,9 @@ class RegisterRestServlet(RestServlet):
                     medium = auth_result[login_type]["medium"]
                     address = auth_result[login_type]["address"]
 
-                    if not check_3pid_allowed(self.hs, medium, address):
+                    if not await check_3pid_allowed(
+                        self.hs, medium, address, registration=True
+                    ):
                         raise SynapseError(
                             403,
                             "Third party identifiers (email/phone numbers)"
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 389adf00f6..1e9c2faa64 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -32,7 +32,12 @@ logger = logging.getLogger(__name__)
 MAX_EMAIL_ADDRESS_LENGTH = 500
 
 
-def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool:
+async def check_3pid_allowed(
+    hs: "HomeServer",
+    medium: str,
+    address: str,
+    registration: bool = False,
+) -> bool:
     """Checks whether a given format of 3PID is allowed to be used on this HS
 
     Args:
@@ -40,9 +45,15 @@ def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool:
         medium: 3pid medium - e.g. email, msisdn
         address: address within that medium (e.g. "wotan@matrix.org")
             msisdns need to first have been canonicalised
+        registration: whether we want to bind the 3PID as part of registering a new user.
+
     Returns:
         bool: whether the 3PID medium/address is allowed to be added to this HS
     """
+    if not await hs.get_password_auth_provider().is_3pid_allowed(
+        medium, address, registration
+    ):
+        return False
 
     if hs.config.registration.allowed_local_3pids:
         for constraint in hs.config.registration.allowed_local_3pids:
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 94809cb8be..4740dd0a65 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -21,13 +21,15 @@ from twisted.internet import defer
 
 import synapse
 from synapse.api.constants import LoginType
+from synapse.api.errors import Codes
 from synapse.handlers.auth import load_legacy_password_auth_providers
 from synapse.module_api import ModuleApi
-from synapse.rest.client import devices, login, logout, register
+from synapse.rest.client import account, devices, login, logout, register
 from synapse.types import JsonDict, UserID
 
 from tests import unittest
 from tests.server import FakeChannel
+from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
 # (possibly experimental) login flows we expect to appear in the list after the normal
@@ -158,6 +160,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         devices.register_servlets,
         logout.register_servlets,
         register.register_servlets,
+        account.register_servlets,
     ]
 
     def setUp(self):
@@ -803,6 +806,77 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         # Check that the callback has been called.
         m.assert_called_once()
 
+    # Set some email configuration so the test doesn't fail because of its absence.
+    @override_config({"email": {"notif_from": "noreply@test"}})
+    def test_3pid_allowed(self):
+        """Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse
+        to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind
+        the 3PID. Also checks that the module is passed a boolean indicating whether the
+        user to bind this 3PID to is currently registering.
+        """
+        self._test_3pid_allowed("rin", False)
+        self._test_3pid_allowed("kitay", True)
+
+    def _test_3pid_allowed(self, username: str, registration: bool):
+        """Tests that the "is_3pid_allowed" module callback is called correctly, using
+        either /register or /account URLs depending on the arguments.
+
+        Args:
+            username: The username to use for the test.
+            registration: Whether to test with registration URLs.
+        """
+        self.hs.get_identity_handler().send_threepid_validation = Mock(
+            return_value=make_awaitable(0),
+        )
+
+        m = Mock(return_value=make_awaitable(False))
+        self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
+
+        self.register_user(username, "password")
+        tok = self.login(username, "password")
+
+        if registration:
+            url = "/register/email/requestToken"
+        else:
+            url = "/account/3pid/email/requestToken"
+
+        channel = self.make_request(
+            "POST",
+            url,
+            {
+                "client_secret": "foo",
+                "email": "foo@test.com",
+                "send_attempt": 0,
+            },
+            access_token=tok,
+        )
+        self.assertEqual(channel.code, 403, channel.result)
+        self.assertEqual(
+            channel.json_body["errcode"],
+            Codes.THREEPID_DENIED,
+            channel.json_body,
+        )
+
+        m.assert_called_once_with("email", "foo@test.com", registration)
+
+        m = Mock(return_value=make_awaitable(True))
+        self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
+
+        channel = self.make_request(
+            "POST",
+            url,
+            {
+                "client_secret": "foo",
+                "email": "bar@test.com",
+                "send_attempt": 0,
+            },
+            access_token=tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertIn("sid", channel.json_body)
+
+        m.assert_called_once_with("email", "bar@test.com", registration)
+
     def _setup_get_username_for_registration(self) -> Mock:
         """Registers a get_username_for_registration callback that appends "-foo" to the
         username the client is trying to register.