diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index a83df7759d..f949308994 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -173,6 +173,37 @@ class RegistrationWorkerStore(SQLBaseStore):
"set_account_validity_for_user", set_account_validity_for_user_txn
)
+ async def get_expired_users(self):
+ """Get UserIDs of all expired users.
+
+ Users who are not active, or do not have profile information, are
+ excluded from the results.
+
+ Returns:
+ Deferred[List[UserID]]: List of expired user IDs
+ """
+
+ def get_expired_users_txn(txn, now_ms):
+ # We need to use pattern matching as profiles.user_id is confusingly just the
+ # user's localpart, whereas account_validity.user_id is a full user ID
+ sql = """
+ SELECT av.user_id from account_validity AS av
+ LEFT JOIN profiles as p
+ ON av.user_id LIKE '%%' || p.user_id || ':%%'
+ WHERE expiration_ts_ms <= ?
+ AND p.active = 1
+ """
+ txn.execute(sql, (now_ms,))
+ rows = txn.fetchall()
+
+ return [UserID.from_string(row[0]) for row in rows]
+
+ res = await self.db_pool.runInteraction(
+ "get_expired_users", get_expired_users_txn, self.clock.time_msec()
+ )
+
+ return res
+
async def set_renewal_token_for_user(
self, user_id: str, renewal_token: str
) -> None:
@@ -279,6 +310,54 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="delete_account_validity_for_user",
)
+ async def get_info_for_users(
+ self, user_ids: List[str],
+ ):
+ """Return the user info for a given set of users
+
+ Args:
+ user_ids: A list of users to return information about
+
+ Returns:
+ Deferred[Dict[str, bool]]: A dictionary mapping each user ID to
+ a dict with the following keys:
+ * expired - whether this is an expired user
+ * deactivated - whether this is a deactivated user
+ """
+ # Get information of all our local users
+ def _get_info_for_users_txn(txn):
+ rows = []
+
+ for user_id in user_ids:
+ sql = """
+ SELECT u.name, u.deactivated, av.expiration_ts_ms
+ FROM users as u
+ LEFT JOIN account_validity as av
+ ON av.user_id = u.name
+ WHERE u.name = ?
+ """
+
+ txn.execute(sql, (user_id,))
+ row = txn.fetchone()
+ if row:
+ rows.append(row)
+
+ return rows
+
+ info_rows = await self.db_pool.runInteraction(
+ "get_info_for_users", _get_info_for_users_txn
+ )
+
+ return {
+ user_id: {
+ "expired": (
+ expiration is not None and self.clock.time_msec() >= expiration
+ ),
+ "deactivated": deactivated == 1,
+ }
+ for user_id, deactivated, expiration in info_rows
+ }
+
async def is_server_admin(self, user: UserID) -> bool:
"""Determines if a user is an admin of this homeserver.
|