summary refs log tree commit diff
path: root/synapse/storage/databases/main/registration.py
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2022-09-29 14:23:24 +0100
committerGitHub <noreply@github.com>2022-09-29 15:23:24 +0200
commitbe76cd8200b18f3c68b895f85ac7ef5b0ddc2466 (patch)
treec32c1b1e7a835b9f970fc4ae28ded0dd858ff841 /synapse/storage/databases/main/registration.py
parentExplicit cast to enforce type hints. (#13939) (diff)
downloadsynapse-be76cd8200b18f3c68b895f85ac7ef5b0ddc2466.tar.xz
Allow admins to require a manual approval process before new accounts can be used (using MSC3866) (#13556)
Diffstat (limited to 'synapse/storage/databases/main/registration.py')
-rw-r--r--synapse/storage/databases/main/registration.py150
1 files changed, 132 insertions, 18 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index ac821878b0..2996d6bb4d 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -166,27 +166,49 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     @cached()
     async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
         """Deprecated: use get_userinfo_by_id instead"""
-        return await self.db_pool.simple_select_one(
-            table="users",
-            keyvalues={"name": user_id},
-            retcols=[
-                "name",
-                "password_hash",
-                "is_guest",
-                "admin",
-                "consent_version",
-                "consent_ts",
-                "consent_server_notice_sent",
-                "appservice_id",
-                "creation_ts",
-                "user_type",
-                "deactivated",
-                "shadow_banned",
-            ],
-            allow_none=True,
+
+        def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
+            # We could technically use simple_select_one here, but it would not perform
+            # the COALESCEs (unless hacked into the column names), which could yield
+            # confusing results.
+            txn.execute(
+                """
+                SELECT
+                    name, password_hash, is_guest, admin, consent_version, consent_ts,
+                    consent_server_notice_sent, appservice_id, creation_ts, user_type,
+                    deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
+                    COALESCE(approved, TRUE) AS approved
+                FROM users
+                WHERE name = ?
+                """,
+                (user_id,),
+            )
+
+            rows = self.db_pool.cursor_to_dict(txn)
+
+            if len(rows) == 0:
+                return None
+
+            return rows[0]
+
+        row = await self.db_pool.runInteraction(
             desc="get_user_by_id",
+            func=get_user_by_id_txn,
         )
 
+        if row is not None:
+            # If we're using SQLite our boolean values will be integers. Because we
+            # present some of this data as is to e.g. server admins via REST APIs, we
+            # want to make sure we're returning the right type of data.
+            # Note: when adding a column name to this list, be wary of NULLable columns,
+            # since NULL values will be turned into False.
+            boolean_columns = ["admin", "deactivated", "shadow_banned", "approved"]
+            for column in boolean_columns:
+                if not isinstance(row[column], bool):
+                    row[column] = bool(row[column])
+
+        return row
+
     async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
         """Get a UserInfo object for a user by user ID.
 
@@ -1779,6 +1801,40 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         return res if res else False
 
+    @cached()
+    async def is_user_approved(self, user_id: str) -> bool:
+        """Checks if a user is approved and therefore can be allowed to log in.
+
+        If the user's 'approved' column is NULL, we consider it as true given it means
+        the user was registered when support for an approval flow was either disabled
+        or nonexistent.
+
+        Args:
+            user_id: the user to check the approval status of.
+
+        Returns:
+            A boolean that is True if the user is approved, False otherwise.
+        """
+
+        def is_user_approved_txn(txn: LoggingTransaction) -> bool:
+            txn.execute(
+                """
+                SELECT COALESCE(approved, TRUE) AS approved FROM users WHERE name = ?
+                """,
+                (user_id,),
+            )
+
+            rows = self.db_pool.cursor_to_dict(txn)
+
+            # We cast to bool because the value returned by the database engine might
+            # be an integer if we're using SQLite.
+            return bool(rows[0]["approved"])
+
+        return await self.db_pool.runInteraction(
+            desc="is_user_pending_approval",
+            func=is_user_approved_txn,
+        )
+
 
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
     def __init__(
@@ -1916,6 +1972,29 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
         txn.call_after(self.is_guest.invalidate, (user_id,))
 
+    def update_user_approval_status_txn(
+        self, txn: LoggingTransaction, user_id: str, approved: bool
+    ) -> None:
+        """Set the user's 'approved' flag to the given value.
+
+        The boolean is turned into an int because the column is a smallint.
+
+        Args:
+            txn: the current database transaction.
+            user_id: the user to update the flag for.
+            approved: the value to set the flag to.
+        """
+        self.db_pool.simple_update_one_txn(
+            txn=txn,
+            table="users",
+            keyvalues={"name": user_id},
+            updatevalues={"approved": approved},
+        )
+
+        # Invalidate the caches of methods that read the value of the 'approved' flag.
+        self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+        self._invalidate_cache_and_stream(txn, self.is_user_approved, (user_id,))
+
 
 class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
     def __init__(
@@ -1933,6 +2012,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
         self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
 
+        # If support for MSC3866 is enabled and configured to require approval for new
+        # account, we will create new users with an 'approved' flag set to false.
+        self._require_approval = (
+            hs.config.experimental.msc3866.enabled
+            and hs.config.experimental.msc3866.require_approval_for_new_accounts
+        )
+
     async def add_access_token_to_user(
         self,
         user_id: str,
@@ -2065,6 +2151,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         admin: bool = False,
         user_type: Optional[str] = None,
         shadow_banned: bool = False,
+        approved: bool = False,
     ) -> None:
         """Attempts to register an account.
 
@@ -2083,6 +2170,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                 or None for a normal user.
             shadow_banned: Whether the user is shadow-banned, i.e. they may be
                 told their requests succeeded but we ignore them.
+            approved: Whether to consider the user has already been approved by an
+                administrator.
 
         Raises:
             StoreError if the user_id could not be registered.
@@ -2099,6 +2188,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             admin,
             user_type,
             shadow_banned,
+            approved,
         )
 
     def _register_user(
@@ -2113,11 +2203,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         admin: bool,
         user_type: Optional[str],
         shadow_banned: bool,
+        approved: bool,
     ) -> None:
         user_id_obj = UserID.from_string(user_id)
 
         now = int(self._clock.time())
 
+        user_approved = approved or not self._require_approval
+
         try:
             if was_guest:
                 # Ensure that the guest user actually exists
@@ -2143,6 +2236,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                         "admin": 1 if admin else 0,
                         "user_type": user_type,
                         "shadow_banned": shadow_banned,
+                        "approved": user_approved,
                     },
                 )
             else:
@@ -2158,6 +2252,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                         "admin": 1 if admin else 0,
                         "user_type": user_type,
                         "shadow_banned": shadow_banned,
+                        "approved": user_approved,
                     },
                 )
 
@@ -2503,6 +2598,25 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             start_or_continue_validation_session_txn,
         )
 
+    async def update_user_approval_status(
+        self, user_id: UserID, approved: bool
+    ) -> None:
+        """Set the user's 'approved' flag to the given value.
+
+        The boolean will be turned into an int (in update_user_approval_status_txn)
+        because the column is a smallint.
+
+        Args:
+            user_id: the user to update the flag for.
+            approved: the value to set the flag to.
+        """
+        await self.db_pool.runInteraction(
+            "update_user_approval_status",
+            self.update_user_approval_status_txn,
+            user_id.to_string(),
+            approved,
+        )
+
 
 def find_max_generated_user_id_localpart(cur: Cursor) -> int:
     """