diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 0843f10340..a62b4abd4e 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -203,6 +203,7 @@ class DataStore(
deactivated: bool = False,
order_by: str = UserSortOrder.USER_ID.value,
direction: str = "f",
+ approved: bool = True,
) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
@@ -217,6 +218,7 @@ class DataStore(
deactivated: whether to include deactivated users
order_by: the sort order of the returned list
direction: sort ascending or descending
+ approved: whether to include approved users
Returns:
A tuple of a list of mappings from user to information and a count of total users.
"""
@@ -249,6 +251,11 @@ class DataStore(
if not deactivated:
filters.append("deactivated = 0")
+ if not approved:
+ # We ignore NULL values for the approved flag because these should only
+ # be already existing users that we consider as already approved.
+ filters.append("approved IS FALSE")
+
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
sql_base = f"""
@@ -262,7 +269,7 @@ class DataStore(
sql = f"""
SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
- displayname, avatar_url, creation_ts * 1000 as creation_ts
+ displayname, avatar_url, creation_ts * 1000 as creation_ts, approved
{sql_base}
ORDER BY {order_by_column} {order}, u.name ASC
LIMIT ? OFFSET ?
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index ac821878b0..2996d6bb4d 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -166,27 +166,49 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
@cached()
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Deprecated: use get_userinfo_by_id instead"""
- return await self.db_pool.simple_select_one(
- table="users",
- keyvalues={"name": user_id},
- retcols=[
- "name",
- "password_hash",
- "is_guest",
- "admin",
- "consent_version",
- "consent_ts",
- "consent_server_notice_sent",
- "appservice_id",
- "creation_ts",
- "user_type",
- "deactivated",
- "shadow_banned",
- ],
- allow_none=True,
+
+ def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
+ # We could technically use simple_select_one here, but it would not perform
+ # the COALESCEs (unless hacked into the column names), which could yield
+ # confusing results.
+ txn.execute(
+ """
+ SELECT
+ name, password_hash, is_guest, admin, consent_version, consent_ts,
+ consent_server_notice_sent, appservice_id, creation_ts, user_type,
+ deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
+ COALESCE(approved, TRUE) AS approved
+ FROM users
+ WHERE name = ?
+ """,
+ (user_id,),
+ )
+
+ rows = self.db_pool.cursor_to_dict(txn)
+
+ if len(rows) == 0:
+ return None
+
+ return rows[0]
+
+ row = await self.db_pool.runInteraction(
desc="get_user_by_id",
+ func=get_user_by_id_txn,
)
+ if row is not None:
+ # If we're using SQLite our boolean values will be integers. Because we
+ # present some of this data as is to e.g. server admins via REST APIs, we
+ # want to make sure we're returning the right type of data.
+ # Note: when adding a column name to this list, be wary of NULLable columns,
+ # since NULL values will be turned into False.
+ boolean_columns = ["admin", "deactivated", "shadow_banned", "approved"]
+ for column in boolean_columns:
+ if not isinstance(row[column], bool):
+ row[column] = bool(row[column])
+
+ return row
+
async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Get a UserInfo object for a user by user ID.
@@ -1779,6 +1801,40 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return res if res else False
+ @cached()
+ async def is_user_approved(self, user_id: str) -> bool:
+ """Checks if a user is approved and therefore can be allowed to log in.
+
+ If the user's 'approved' column is NULL, we consider it as true given it means
+ the user was registered when support for an approval flow was either disabled
+ or nonexistent.
+
+ Args:
+ user_id: the user to check the approval status of.
+
+ Returns:
+ A boolean that is True if the user is approved, False otherwise.
+ """
+
+ def is_user_approved_txn(txn: LoggingTransaction) -> bool:
+ txn.execute(
+ """
+ SELECT COALESCE(approved, TRUE) AS approved FROM users WHERE name = ?
+ """,
+ (user_id,),
+ )
+
+ rows = self.db_pool.cursor_to_dict(txn)
+
+ # We cast to bool because the value returned by the database engine might
+ # be an integer if we're using SQLite.
+ return bool(rows[0]["approved"])
+
+ return await self.db_pool.runInteraction(
+ desc="is_user_pending_approval",
+ func=is_user_approved_txn,
+ )
+
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(
@@ -1916,6 +1972,29 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
+ def update_user_approval_status_txn(
+ self, txn: LoggingTransaction, user_id: str, approved: bool
+ ) -> None:
+ """Set the user's 'approved' flag to the given value.
+
+ The boolean is turned into an int because the column is a smallint.
+
+ Args:
+ txn: the current database transaction.
+ user_id: the user to update the flag for.
+ approved: the value to set the flag to.
+ """
+ self.db_pool.simple_update_one_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"approved": approved},
+ )
+
+ # Invalidate the caches of methods that read the value of the 'approved' flag.
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+ self._invalidate_cache_and_stream(txn, self.is_user_approved, (user_id,))
+
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def __init__(
@@ -1933,6 +2012,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
+ # If support for MSC3866 is enabled and configured to require approval for new
+ # account, we will create new users with an 'approved' flag set to false.
+ self._require_approval = (
+ hs.config.experimental.msc3866.enabled
+ and hs.config.experimental.msc3866.require_approval_for_new_accounts
+ )
+
async def add_access_token_to_user(
self,
user_id: str,
@@ -2065,6 +2151,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
admin: bool = False,
user_type: Optional[str] = None,
shadow_banned: bool = False,
+ approved: bool = False,
) -> None:
"""Attempts to register an account.
@@ -2083,6 +2170,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
or None for a normal user.
shadow_banned: Whether the user is shadow-banned, i.e. they may be
told their requests succeeded but we ignore them.
+ approved: Whether to consider the user has already been approved by an
+ administrator.
Raises:
StoreError if the user_id could not be registered.
@@ -2099,6 +2188,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
admin,
user_type,
shadow_banned,
+ approved,
)
def _register_user(
@@ -2113,11 +2203,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
admin: bool,
user_type: Optional[str],
shadow_banned: bool,
+ approved: bool,
) -> None:
user_id_obj = UserID.from_string(user_id)
now = int(self._clock.time())
+ user_approved = approved or not self._require_approval
+
try:
if was_guest:
# Ensure that the guest user actually exists
@@ -2143,6 +2236,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"admin": 1 if admin else 0,
"user_type": user_type,
"shadow_banned": shadow_banned,
+ "approved": user_approved,
},
)
else:
@@ -2158,6 +2252,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"admin": 1 if admin else 0,
"user_type": user_type,
"shadow_banned": shadow_banned,
+ "approved": user_approved,
},
)
@@ -2503,6 +2598,25 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
start_or_continue_validation_session_txn,
)
+ async def update_user_approval_status(
+ self, user_id: UserID, approved: bool
+ ) -> None:
+ """Set the user's 'approved' flag to the given value.
+
+ The boolean will be turned into an int (in update_user_approval_status_txn)
+ because the column is a smallint.
+
+ Args:
+ user_id: the user to update the flag for.
+ approved: the value to set the flag to.
+ """
+ await self.db_pool.runInteraction(
+ "update_user_approval_status",
+ self.update_user_approval_status_txn,
+ user_id.to_string(),
+ approved,
+ )
+
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
"""
diff --git a/synapse/storage/schema/main/delta/73/03users_approved_column.sql b/synapse/storage/schema/main/delta/73/03users_approved_column.sql
new file mode 100644
index 0000000000..5328d592ea
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/03users_approved_column.sql
@@ -0,0 +1,20 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add a column to the users table to track whether the user needs to be approved by an
+-- administrator.
+-- A NULL column means the user was created before this feature was supported by Synapse,
+-- and should be considered as TRUE.
+ALTER TABLE users ADD COLUMN approved BOOLEAN;
|