summary refs log tree commit diff
path: root/synapse/storage/databases/main/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/registration.py')
-rw-r--r--synapse/storage/databases/main/registration.py136
1 files changed, 120 insertions, 16 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py

index 90a8f664ef..c439b58db6 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -92,13 +92,19 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): id_column=None, ) - self._account_validity = hs.config.account_validity - if hs.config.run_background_tasks and self._account_validity.enabled: - self._clock.call_later( - 0.0, - self._set_expiration_date_when_missing, + self._account_validity_enabled = hs.config.account_validity_enabled + if self._account_validity_enabled: + self._account_validity_period = hs.config.account_validity_period + self._account_validity_startup_job_max_delta = ( + hs.config.account_validity_startup_job_max_delta ) + if hs.config.run_background_tasks: + self._clock.call_later( + 0.0, + self._set_expiration_date_when_missing, + ) + # Create a background job for culling expired 3PID validity tokens if hs.config.run_background_tasks: self._clock.looping_call( @@ -195,6 +201,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): expiration_ts: int, email_sent: bool, renewal_token: Optional[str] = None, + token_used_ts: Optional[int] = None, ) -> None: """Updates the account validity properties of the given account, with the given values. @@ -208,6 +215,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): period. renewal_token: Renewal token the user can use to extend the validity of their account. Defaults to no token. + token_used_ts: A timestamp of when the current token was used to renew + the account. """ def set_account_validity_for_user_txn(txn): @@ -219,6 +228,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "expiration_ts_ms": expiration_ts, "email_sent": email_sent, "renewal_token": renewal_token, + "token_used_ts_ms": token_used_ts, }, ) self._invalidate_cache_and_stream( @@ -229,10 +239,41 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "set_account_validity_for_user", set_account_validity_for_user_txn ) + async def get_expired_users(self): + """Get UserIDs of all expired users. + + Users who are not active, or do not have profile information, are + excluded from the results. + + Returns: + Deferred[List[UserID]]: List of expired user IDs + """ + + def get_expired_users_txn(txn, now_ms): + # We need to use pattern matching as profiles.user_id is confusingly just the + # user's localpart, whereas account_validity.user_id is a full user ID + sql = """ + SELECT av.user_id from account_validity AS av + LEFT JOIN profiles as p + ON av.user_id LIKE '%%' || p.user_id || ':%%' + WHERE expiration_ts_ms <= ? + AND p.active = 1 + """ + txn.execute(sql, (now_ms,)) + rows = txn.fetchall() + + return [UserID.from_string(row[0]) for row in rows] + + res = await self.db_pool.runInteraction( + "get_expired_users", get_expired_users_txn, self._clock.time_msec() + ) + + return res + async def set_renewal_token_for_user( self, user_id: str, renewal_token: str ) -> None: - """Defines a renewal token for a given user. + """Defines a renewal token for a given user, and clears the token_used timestamp. Args: user_id: ID of the user to set the renewal token for. @@ -245,26 +286,40 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): await self.db_pool.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, - updatevalues={"renewal_token": renewal_token}, + updatevalues={"renewal_token": renewal_token, "token_used_ts_ms": None}, desc="set_renewal_token_for_user", ) - async def get_user_from_renewal_token(self, renewal_token: str) -> str: - """Get a user ID from a renewal token. + async def get_user_from_renewal_token( + self, renewal_token: str + ) -> Tuple[str, int, Optional[int]]: + """Get a user ID and renewal status from a renewal token. Args: renewal_token: The renewal token to perform the lookup with. Returns: - The ID of the user to which the token belongs. + A tuple of containing the following values: + * The ID of a user to which the token belongs. + * An int representing the user's expiry timestamp as milliseconds since the + epoch, or 0 if the token was invalid. + * An optional int representing the timestamp of when the user renewed their + account timestamp as milliseconds since the epoch. None if the account + has not been renewed using the current token yet. """ - return await self.db_pool.simple_select_one_onecol( + ret_dict = await self.db_pool.simple_select_one( table="account_validity", keyvalues={"renewal_token": renewal_token}, - retcol="user_id", + retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"], desc="get_user_from_renewal_token", ) + return ( + ret_dict["user_id"], + ret_dict["expiration_ts_ms"], + ret_dict["token_used_ts_ms"], + ) + async def get_renewal_token_for_user(self, user_id: str) -> str: """Get the renewal token associated with a given user ID. @@ -303,7 +358,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "get_users_expiring_soon", select_users_txn, self._clock.time_msec(), - self.config.account_validity.renew_at, + self.config.account_validity_renew_at, ) async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None: @@ -335,6 +390,55 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="delete_account_validity_for_user", ) + async def get_info_for_users( + self, + user_ids: List[str], + ): + """Return the user info for a given set of users + + Args: + user_ids: A list of users to return information about + + Returns: + Deferred[Dict[str, bool]]: A dictionary mapping each user ID to + a dict with the following keys: + * expired - whether this is an expired user + * deactivated - whether this is a deactivated user + """ + # Get information of all our local users + def _get_info_for_users_txn(txn): + rows = [] + + for user_id in user_ids: + sql = """ + SELECT u.name, u.deactivated, av.expiration_ts_ms + FROM users as u + LEFT JOIN account_validity as av + ON av.user_id = u.name + WHERE u.name = ? + """ + + txn.execute(sql, (user_id,)) + row = txn.fetchone() + if row: + rows.append(row) + + return rows + + info_rows = await self.db_pool.runInteraction( + "get_info_for_users", _get_info_for_users_txn + ) + + return { + user_id: { + "expired": ( + expiration is not None and self._clock.time_msec() >= expiration + ), + "deactivated": deactivated == 1, + } + for user_id, deactivated, expiration in info_rows + } + async def is_server_admin(self, user: UserID) -> bool: """Determines if a user is an admin of this homeserver. @@ -965,11 +1069,11 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): delta equal to 10% of the validity period. """ now_ms = self._clock.time_msec() - expiration_ts = now_ms + self._account_validity.period + expiration_ts = now_ms + self._account_validity_period if use_delta: expiration_ts = self.rand.randrange( - expiration_ts - self._account_validity.startup_job_max_delta, + expiration_ts - self._account_validity_startup_job_max_delta, expiration_ts, ) @@ -1413,7 +1517,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): except self.database_engine.module.IntegrityError: raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) - if self._account_validity.enabled: + if self._account_validity_enabled: self.set_expiration_date_for_user_txn(txn, user_id) if create_profile_with_displayname: