summary refs log tree commit diff
path: root/synapse/storage/databases/main/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/registration.py')
-rw-r--r--synapse/storage/databases/main/registration.py99
1 files changed, 98 insertions, 1 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py

index f618629e09..35acf48757 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -17,7 +17,7 @@ import logging import re -from typing import Optional +from typing import List, Optional from twisted.internet import defer from twisted.internet.defer import Deferred @@ -126,6 +126,23 @@ class RegistrationWorkerStore(SQLBaseStore): return res @defer.inlineCallbacks + def is_account_expired(self, user_id: str, current_ts: int): + """ + Returns whether an user account is expired. + + Args: + user_id: The user's ID + current_ts: The current timestamp + + Returns: + Deferred[bool]: whether the user account has expired + """ + expiration_ts = yield self.get_expiration_ts_for_user(user_id) + if expiration_ts is not None and current_ts >= expiration_ts: + return True + return False + + @defer.inlineCallbacks def set_account_validity_for_user( self, user_id, expiration_ts, email_sent, renewal_token=None ): @@ -163,6 +180,37 @@ class RegistrationWorkerStore(SQLBaseStore): ) @defer.inlineCallbacks + def get_expired_users(self): + """Get UserIDs of all expired users. + + Users who are not active, or do not have profile information, are + excluded from the results. + + Returns: + Deferred[List[UserID]]: List of expired user IDs + """ + + def get_expired_users_txn(txn, now_ms): + # We need to use pattern matching as profiles.user_id is confusingly just the + # user's localpart, whereas account_validity.user_id is a full user ID + sql = """ + SELECT av.user_id from account_validity AS av + LEFT JOIN profiles as p + ON av.user_id LIKE '%%' || p.user_id || ':%%' + WHERE expiration_ts_ms <= ? + AND p.active = 1 + """ + txn.execute(sql, (now_ms,)) + rows = txn.fetchall() + + return [UserID.from_string(row[0]) for row in rows] + + res = yield self.db_pool.runInteraction( + "get_expired_users", get_expired_users_txn, self.clock.time_msec() + ) + return res + + @defer.inlineCallbacks def set_renewal_token_for_user(self, user_id, renewal_token): """Defines a renewal token for a given user. @@ -278,6 +326,55 @@ class RegistrationWorkerStore(SQLBaseStore): desc="delete_account_validity_for_user", ) + @defer.inlineCallbacks + def get_info_for_users( + self, user_ids: List[str], + ): + """Return the user info for a given set of users + + Args: + user_ids: A list of users to return information about + + Returns: + Deferred[Dict[str, bool]]: A dictionary mapping each user ID to + a dict with the following keys: + * expired - whether this is an expired user + * deactivated - whether this is a deactivated user + """ + # Get information of all our local users + def _get_info_for_users_txn(txn): + rows = [] + + for user_id in user_ids: + sql = """ + SELECT u.name, u.deactivated, av.expiration_ts_ms + FROM users as u + LEFT JOIN account_validity as av + ON av.user_id = u.name + WHERE u.name = ? + """ + + txn.execute(sql, (user_id,)) + row = txn.fetchone() + if row: + rows.append(row) + + return rows + + info_rows = yield self.db_pool.runInteraction( + "get_info_for_users", _get_info_for_users_txn + ) + + return { + user_id: { + "expired": ( + expiration is not None and self.clock.time_msec() >= expiration + ), + "deactivated": deactivated == 1, + } + for user_id, deactivated, expiration in info_rows + } + async def is_server_admin(self, user): """Determines if a user is an admin of this homeserver.