diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index e32c93e234..6959d1aa7e 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -2064,6 +2064,7 @@ GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
[JsonDict, JsonDict],
Awaitable[Optional[str]],
]
+IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
class PasswordAuthProvider:
@@ -2079,6 +2080,7 @@ class PasswordAuthProvider:
self.get_username_for_registration_callbacks: List[
GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = []
+ self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
# Mapping from login type to login parameters
self._supported_login_types: Dict[str, Iterable[str]] = {}
@@ -2090,6 +2092,7 @@ class PasswordAuthProvider:
self,
check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
+ is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None,
auth_checkers: Optional[
Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
] = None,
@@ -2145,6 +2148,9 @@ class PasswordAuthProvider:
get_username_for_registration,
)
+ if is_3pid_allowed is not None:
+ self.is_3pid_allowed_callbacks.append(is_3pid_allowed)
+
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
"""Get the login types supported by this password provider
@@ -2343,3 +2349,41 @@ class PasswordAuthProvider:
raise SynapseError(code=500, msg="Internal Server Error")
return None
+
+ async def is_3pid_allowed(
+ self,
+ medium: str,
+ address: str,
+ registration: bool,
+ ) -> bool:
+ """Check if the user can be allowed to bind a 3PID on this homeserver.
+
+ Args:
+ medium: The medium of the 3PID.
+ address: The address of the 3PID.
+ registration: Whether the 3PID is being bound when registering a new user.
+
+ Returns:
+ Whether the 3PID is allowed to be bound on this homeserver
+ """
+ for callback in self.is_3pid_allowed_callbacks:
+ try:
+ res = await callback(medium, address, registration)
+
+ if res is False:
+ return res
+ elif not isinstance(res, bool):
+ # mypy complains that this line is unreachable because it assumes the
+ # data returned by the module fits the expected type. We just want
+ # to make sure this is the case.
+ logger.warning( # type: ignore[unreachable]
+ "Ignoring non-string value returned by"
+ " is_3pid_allowed callback %s: %s",
+ callback,
+ res,
+ )
+ except Exception as e:
+ logger.error("Module raised an exception in is_3pid_allowed: %s", e)
+ raise SynapseError(code=500, msg="Internal Server Error")
+
+ return True
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 29fbc73c97..a91a7fa3ce 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -72,6 +72,7 @@ from synapse.handlers.auth import (
CHECK_3PID_AUTH_CALLBACK,
CHECK_AUTH_CALLBACK,
GET_USERNAME_FOR_REGISTRATION_CALLBACK,
+ IS_3PID_ALLOWED_CALLBACK,
ON_LOGGED_OUT_CALLBACK,
AuthHandler,
)
@@ -312,6 +313,7 @@ class ModuleApi:
auth_checkers: Optional[
Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
] = None,
+ is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None,
get_username_for_registration: Optional[
GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = None,
@@ -323,6 +325,7 @@ class ModuleApi:
return self._password_auth_provider.register_password_auth_provider_callbacks(
check_3pid_auth=check_3pid_auth,
on_logged_out=on_logged_out,
+ is_3pid_allowed=is_3pid_allowed,
auth_checkers=auth_checkers,
get_username_for_registration=get_username_for_registration,
)
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 6b272658fc..cfa2aee76d 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -385,7 +385,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
- if not check_3pid_allowed(self.hs, "email", email):
+ if not await check_3pid_allowed(self.hs, "email", email):
raise SynapseError(
403,
"Your email domain is not authorized on this server",
@@ -468,7 +468,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(country, phone_number)
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ if not await check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
403,
"Account phone numbers are not authorized on this server",
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index c283313e8d..c965e2bda2 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -112,7 +112,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
- if not check_3pid_allowed(self.hs, "email", email):
+ if not await check_3pid_allowed(self.hs, "email", email, registration=True):
raise SynapseError(
403,
"Your email domain is not authorized to register on this server",
@@ -192,7 +192,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(country, phone_number)
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ if not await check_3pid_allowed(self.hs, "msisdn", msisdn, registration=True):
raise SynapseError(
403,
"Phone numbers are not authorized to register on this server",
@@ -616,7 +616,9 @@ class RegisterRestServlet(RestServlet):
medium = auth_result[login_type]["medium"]
address = auth_result[login_type]["address"]
- if not check_3pid_allowed(self.hs, medium, address):
+ if not await check_3pid_allowed(
+ self.hs, medium, address, registration=True
+ ):
raise SynapseError(
403,
"Third party identifiers (email/phone numbers)"
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 389adf00f6..1e9c2faa64 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -32,7 +32,12 @@ logger = logging.getLogger(__name__)
MAX_EMAIL_ADDRESS_LENGTH = 500
-def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool:
+async def check_3pid_allowed(
+ hs: "HomeServer",
+ medium: str,
+ address: str,
+ registration: bool = False,
+) -> bool:
"""Checks whether a given format of 3PID is allowed to be used on this HS
Args:
@@ -40,9 +45,15 @@ def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool:
medium: 3pid medium - e.g. email, msisdn
address: address within that medium (e.g. "wotan@matrix.org")
msisdns need to first have been canonicalised
+ registration: whether we want to bind the 3PID as part of registering a new user.
+
Returns:
bool: whether the 3PID medium/address is allowed to be added to this HS
"""
+ if not await hs.get_password_auth_provider().is_3pid_allowed(
+ medium, address, registration
+ ):
+ return False
if hs.config.registration.allowed_local_3pids:
for constraint in hs.config.registration.allowed_local_3pids:
|