summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py44
1 files changed, 44 insertions, 0 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index e32c93e234..6959d1aa7e 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -2064,6 +2064,7 @@ GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
     [JsonDict, JsonDict],
     Awaitable[Optional[str]],
 ]
+IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
 
 
 class PasswordAuthProvider:
@@ -2079,6 +2080,7 @@ class PasswordAuthProvider:
         self.get_username_for_registration_callbacks: List[
             GET_USERNAME_FOR_REGISTRATION_CALLBACK
         ] = []
+        self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
 
         # Mapping from login type to login parameters
         self._supported_login_types: Dict[str, Iterable[str]] = {}
@@ -2090,6 +2092,7 @@ class PasswordAuthProvider:
         self,
         check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
         on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
+        is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None,
         auth_checkers: Optional[
             Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
         ] = None,
@@ -2145,6 +2148,9 @@ class PasswordAuthProvider:
                 get_username_for_registration,
             )
 
+        if is_3pid_allowed is not None:
+            self.is_3pid_allowed_callbacks.append(is_3pid_allowed)
+
     def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
         """Get the login types supported by this password provider
 
@@ -2343,3 +2349,41 @@ class PasswordAuthProvider:
                 raise SynapseError(code=500, msg="Internal Server Error")
 
         return None
+
+    async def is_3pid_allowed(
+        self,
+        medium: str,
+        address: str,
+        registration: bool,
+    ) -> bool:
+        """Check if the user can be allowed to bind a 3PID on this homeserver.
+
+        Args:
+            medium: The medium of the 3PID.
+            address: The address of the 3PID.
+            registration: Whether the 3PID is being bound when registering a new user.
+
+        Returns:
+            Whether the 3PID is allowed to be bound on this homeserver
+        """
+        for callback in self.is_3pid_allowed_callbacks:
+            try:
+                res = await callback(medium, address, registration)
+
+                if res is False:
+                    return res
+                elif not isinstance(res, bool):
+                    # mypy complains that this line is unreachable because it assumes the
+                    # data returned by the module fits the expected type. We just want
+                    # to make sure this is the case.
+                    logger.warning(  # type: ignore[unreachable]
+                        "Ignoring non-string value returned by"
+                        " is_3pid_allowed callback %s: %s",
+                        callback,
+                        res,
+                    )
+            except Exception as e:
+                logger.error("Module raised an exception in is_3pid_allowed: %s", e)
+                raise SynapseError(code=500, msg="Internal Server Error")
+
+        return True