diff options
Diffstat (limited to 'synapse/storage/databases/main/registration.py')
-rw-r--r-- | synapse/storage/databases/main/registration.py | 62 |
1 files changed, 58 insertions, 4 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index c582cf0573..d3a01d526f 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -205,7 +205,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): name, password_hash, is_guest, admin, consent_version, consent_ts, consent_server_notice_sent, appservice_id, creation_ts, user_type, deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned, - COALESCE(approved, TRUE) AS approved + COALESCE(approved, TRUE) AS approved, + COALESCE(locked, FALSE) AS locked FROM users WHERE name = ? """, @@ -230,10 +231,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): # want to make sure we're returning the right type of data. # Note: when adding a column name to this list, be wary of NULLable columns, # since NULL values will be turned into False. - boolean_columns = ["admin", "deactivated", "shadow_banned", "approved"] + boolean_columns = [ + "admin", + "deactivated", + "shadow_banned", + "approved", + "locked", + ] for column in boolean_columns: - if not isinstance(row[column], bool): - row[column] = bool(row[column]) + row[column] = bool(row[column]) return row @@ -1116,6 +1122,27 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): # Convert the integer into a boolean. return res == 1 + @cached() + async def get_user_locked_status(self, user_id: str) -> bool: + """Retrieve the value for the `locked` property for the provided user. + + Args: + user_id: The ID of the user to retrieve the status for. + + Returns: + True if the user was locked, false if the user is still active. + """ + + res = await self.db_pool.simple_select_one_onecol( + table="users", + keyvalues={"name": user_id}, + retcol="locked", + desc="get_user_locked_status", + ) + + # Convert the potential integer into a boolean. + return bool(res) + async def get_threepid_validation_session( self, medium: Optional[str], @@ -2111,6 +2138,33 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,)) + async def set_user_locked_status(self, user_id: str, locked: bool) -> None: + """Set the `locked` property for the provided user to the provided value. + + Args: + user_id: The ID of the user to set the status for. + locked: The value to set for `locked`. + """ + + await self.db_pool.runInteraction( + "set_user_locked_status", + self.set_user_locked_status_txn, + user_id, + locked, + ) + + def set_user_locked_status_txn( + self, txn: LoggingTransaction, user_id: str, locked: bool + ) -> None: + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"locked": locked}, + ) + self._invalidate_cache_and_stream(txn, self.get_user_locked_status, (user_id,)) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + def update_user_approval_status_txn( self, txn: LoggingTransaction, user_id: str, approved: bool ) -> None: |