diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 1dd1182e82..983ce13291 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import re
from six import iterkeys
@@ -31,6 +32,8 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
+logger = logging.getLogger(__name__)
+
class RegistrationWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs):
@@ -113,8 +116,9 @@ class RegistrationWorkerStore(SQLBaseStore):
defer.returnValue(res)
@defer.inlineCallbacks
- def set_account_validity_for_user(self, user_id, expiration_ts, email_sent,
- renewal_token=None):
+ def set_account_validity_for_user(
+ self, user_id, expiration_ts, email_sent, renewal_token=None
+ ):
"""Updates the account validity properties of the given account, with the
given values.
@@ -128,6 +132,7 @@ class RegistrationWorkerStore(SQLBaseStore):
renewal_token (str): Renewal token the user can use to extend the validity
of their account. Defaults to no token.
"""
+
def set_account_validity_for_user_txn(txn):
self._simple_update_txn(
txn=txn,
@@ -140,12 +145,11 @@ class RegistrationWorkerStore(SQLBaseStore):
},
)
self._invalidate_cache_and_stream(
- txn, self.get_expiration_ts_for_user, (user_id,),
+ txn, self.get_expiration_ts_for_user, (user_id,)
)
yield self.runInteraction(
- "set_account_validity_for_user",
- set_account_validity_for_user_txn,
+ "set_account_validity_for_user", set_account_validity_for_user_txn
)
@defer.inlineCallbacks
@@ -214,6 +218,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
"""
+
def select_users_txn(txn, now_ms, renew_at):
sql = (
"SELECT user_id, expiration_ts_ms FROM account_validity"
@@ -226,7 +231,8 @@ class RegistrationWorkerStore(SQLBaseStore):
res = yield self.runInteraction(
"get_users_expiring_soon",
select_users_txn,
- self.clock.time_msec(), self.config.account_validity.renew_at,
+ self.clock.time_msec(),
+ self.config.account_validity.renew_at,
)
defer.returnValue(res)
@@ -249,6 +255,20 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@defer.inlineCallbacks
+ def delete_account_validity_for_user(self, user_id):
+ """Deletes the entry for the given user in the account validity table, removing
+ their expiration date and renewal token.
+
+ Args:
+ user_id (str): ID of the user to remove from the account validity table.
+ """
+ yield self._simple_delete_one(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ desc="delete_account_validity_for_user",
+ )
+
+ @defer.inlineCallbacks
def is_server_admin(self, user):
res = yield self._simple_select_one_onecol(
table="users",
@@ -352,7 +372,7 @@ class RegistrationWorkerStore(SQLBaseStore):
WHERE creation_ts > ?
) AS t GROUP BY user_type
"""
- results = {'native': 0, 'guest': 0, 'bridged': 0}
+ results = {"native": 0, "guest": 0, "bridged": 0}
txn.execute(sql, (yesterday,))
for row in txn:
results[row[0]] = row[1]
@@ -418,7 +438,7 @@ class RegistrationWorkerStore(SQLBaseStore):
{"medium": medium, "address": address},
["guest_access_token"],
True,
- 'get_3pid_guest_access_token',
+ "get_3pid_guest_access_token",
)
if ret:
defer.returnValue(ret["guest_access_token"])
@@ -455,11 +475,11 @@ class RegistrationWorkerStore(SQLBaseStore):
txn,
"user_threepids",
{"medium": medium, "address": address},
- ['user_id'],
+ ["user_id"],
True,
)
if ret:
- return ret['user_id']
+ return ret["user_id"]
return None
@defer.inlineCallbacks
@@ -475,8 +495,8 @@ class RegistrationWorkerStore(SQLBaseStore):
ret = yield self._simple_select_list(
"user_threepids",
{"user_id": user_id},
- ['medium', 'address', 'validated_at', 'added_at'],
- 'user_get_threepids',
+ ["medium", "address", "validated_at", "added_at"],
+ "user_get_threepids",
)
defer.returnValue(ret)
@@ -555,11 +575,7 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
return self._simple_select_onecol(
table="user_threepid_id_server",
- keyvalues={
- "user_id": user_id,
- "medium": medium,
- "address": address,
- },
+ keyvalues={"user_id": user_id, "medium": medium, "address": address},
retcol="id_server",
desc="get_id_servers_user_bound",
)
@@ -595,15 +611,80 @@ class RegistrationStore(
self.register_noop_background_update("refresh_tokens_device_index")
self.register_background_update_handler(
- "user_threepids_grandfather", self._bg_user_threepids_grandfather,
+ "user_threepids_grandfather", self._bg_user_threepids_grandfather
+ )
+
+ self.register_background_update_handler(
+ "users_set_deactivated_flag", self._backgroud_update_set_deactivated_flag
)
# Create a background job for culling expired 3PID validity tokens
hs.get_clock().looping_call(
- self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS,
+ self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
)
@defer.inlineCallbacks
+ def _backgroud_update_set_deactivated_flag(self, progress, batch_size):
+ """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
+ for each of them.
+ """
+
+ last_user = progress.get("user_id", "")
+
+ def _backgroud_update_set_deactivated_flag_txn(txn):
+ txn.execute(
+ """
+ SELECT
+ users.name,
+ COUNT(access_tokens.token) AS count_tokens,
+ COUNT(user_threepids.address) AS count_threepids
+ FROM users
+ LEFT JOIN access_tokens ON (access_tokens.user_id = users.name)
+ LEFT JOIN user_threepids ON (user_threepids.user_id = users.name)
+ WHERE (users.password_hash IS NULL OR users.password_hash = '')
+ AND (users.appservice_id IS NULL OR users.appservice_id = '')
+ AND users.is_guest = 0
+ AND users.name > ?
+ GROUP BY users.name
+ ORDER BY users.name ASC
+ LIMIT ?;
+ """,
+ (last_user, batch_size),
+ )
+
+ rows = self.cursor_to_dict(txn)
+
+ if not rows:
+ return True
+
+ rows_processed_nb = 0
+
+ for user in rows:
+ if not user["count_tokens"] and not user["count_threepids"]:
+ self.set_user_deactivated_status_txn(txn, user["name"], True)
+ rows_processed_nb += 1
+
+ logger.info("Marked %d rows as deactivated", rows_processed_nb)
+
+ self._background_update_progress_txn(
+ txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
+ )
+
+ if batch_size > len(rows):
+ return True
+ else:
+ return False
+
+ end = yield self.runInteraction(
+ "users_set_deactivated_flag", _backgroud_update_set_deactivated_flag_txn
+ )
+
+ if end:
+ yield self._end_background_update("users_set_deactivated_flag")
+
+ defer.returnValue(batch_size)
+
+ @defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id=None):
"""Adds an access token for the given user.
@@ -768,7 +849,7 @@ class RegistrationStore(
def user_set_password_hash_txn(txn):
self._simple_update_one_txn(
- txn, 'users', {'name': user_id}, {'password_hash': password_hash}
+ txn, "users", {"name": user_id}, {"password_hash": password_hash}
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
@@ -789,9 +870,9 @@ class RegistrationStore(
def f(txn):
self._simple_update_one_txn(
txn,
- table='users',
- keyvalues={'name': user_id},
- updatevalues={'consent_version': consent_version},
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"consent_version": consent_version},
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
@@ -813,9 +894,9 @@ class RegistrationStore(
def f(txn):
self._simple_update_one_txn(
txn,
- table='users',
- keyvalues={'name': user_id},
- updatevalues={'consent_server_notice_sent': consent_version},
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"consent_server_notice_sent": consent_version},
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
@@ -985,7 +1066,7 @@ class RegistrationStore(
if id_servers:
yield self.runInteraction(
- "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn,
+ "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
)
yield self._end_background_update("user_threepids_grandfather")
@@ -993,12 +1074,7 @@ class RegistrationStore(
defer.returnValue(1)
def get_threepid_validation_session(
- self,
- medium,
- client_secret,
- address=None,
- sid=None,
- validated=True,
+ self, medium, client_secret, address=None, sid=None, validated=True
):
"""Gets a session_id and last_send_attempt (if available) for a
client_secret/medium/(address|session_id) combo
@@ -1018,23 +1094,22 @@ class RegistrationStore(
latest session_id and send_attempt count for this 3PID.
Otherwise None if there hasn't been a previous attempt
"""
- keyvalues = {
- "medium": medium,
- "client_secret": client_secret,
- }
+ keyvalues = {"medium": medium, "client_secret": client_secret}
if address:
keyvalues["address"] = address
if sid:
keyvalues["session_id"] = sid
- assert(address or sid)
+ assert address or sid
def get_threepid_validation_session_txn(txn):
sql = """
SELECT address, session_id, medium, client_secret,
last_send_attempt, validated_at
FROM threepid_validation_session WHERE %s
- """ % (" AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),)
+ """ % (
+ " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),
+ )
if validated is not None:
sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL")
@@ -1049,17 +1124,10 @@ class RegistrationStore(
return rows[0]
return self.runInteraction(
- "get_threepid_validation_session",
- get_threepid_validation_session_txn,
+ "get_threepid_validation_session", get_threepid_validation_session_txn
)
- def validate_threepid_session(
- self,
- session_id,
- client_secret,
- token,
- current_ts,
- ):
+ def validate_threepid_session(self, session_id, client_secret, token, current_ts):
"""Attempt to validate a threepid session using a token
Args:
@@ -1091,7 +1159,7 @@ class RegistrationStore(
if retrieved_client_secret != client_secret:
raise ThreepidValidationError(
- 400, "This client_secret does not match the provided session_id",
+ 400, "This client_secret does not match the provided session_id"
)
row = self._simple_select_one_txn(
@@ -1104,7 +1172,7 @@ class RegistrationStore(
if not row:
raise ThreepidValidationError(
- 400, "Validation token not found or has expired",
+ 400, "Validation token not found or has expired"
)
expires = row["expires"]
next_link = row["next_link"]
@@ -1115,7 +1183,7 @@ class RegistrationStore(
if expires <= current_ts:
raise ThreepidValidationError(
- 400, "This token has expired. Please request a new one",
+ 400, "This token has expired. Please request a new one"
)
# Looks good. Validate the session
@@ -1130,8 +1198,7 @@ class RegistrationStore(
# Return next_link if it exists
return self.runInteraction(
- "validate_threepid_session_txn",
- validate_threepid_session_txn,
+ "validate_threepid_session_txn", validate_threepid_session_txn
)
def upsert_threepid_validation_session(
@@ -1198,6 +1265,7 @@ class RegistrationStore(
token_expires (int): The timestamp for which after the token
will no longer be valid
"""
+
def start_or_continue_validation_session_txn(txn):
# Create or update a validation session
self._simple_upsert_txn(
@@ -1231,6 +1299,7 @@ class RegistrationStore(
def cull_expired_threepid_validation_tokens(self):
"""Remove threepid validation tokens with expiry dates that have passed"""
+
def cull_expired_threepid_validation_tokens_txn(txn, ts):
sql = """
DELETE FROM threepid_validation_token WHERE
@@ -1252,6 +1321,7 @@ class RegistrationStore(
Args:
session_id (str): The ID of the session to delete
"""
+
def delete_threepid_session_txn(txn):
self._simple_delete_txn(
txn,
@@ -1265,6 +1335,53 @@ class RegistrationStore(
)
return self.runInteraction(
- "delete_threepid_session",
- delete_threepid_session_txn,
+ "delete_threepid_session", delete_threepid_session_txn
+ )
+
+ def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
+ self._simple_update_one_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"deactivated": 1 if deactivated else 0},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_deactivated_status, (user_id,)
)
+
+ @defer.inlineCallbacks
+ def set_user_deactivated_status(self, user_id, deactivated):
+ """Set the `deactivated` property for the provided user to the provided value.
+
+ Args:
+ user_id (str): The ID of the user to set the status for.
+ deactivated (bool): The value to set for `deactivated`.
+ """
+
+ yield self.runInteraction(
+ "set_user_deactivated_status",
+ self.set_user_deactivated_status_txn,
+ user_id,
+ deactivated,
+ )
+
+ @cachedInlineCallbacks()
+ def get_user_deactivated_status(self, user_id):
+ """Retrieve the value for the `deactivated` property for the provided user.
+
+ Args:
+ user_id (str): The ID of the user to retrieve the status for.
+
+ Returns:
+ defer.Deferred(bool): The requested value.
+ """
+
+ res = yield self._simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="deactivated",
+ desc="get_user_deactivated_status",
+ )
+
+ # Convert the integer into a boolean.
+ defer.returnValue(res == 1)
|