diff options
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r-- | synapse/storage/registration.py | 168 |
1 files changed, 151 insertions, 17 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 643f7a3808..a1085ad80c 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -32,6 +32,7 @@ class RegistrationWorkerStore(SQLBaseStore): super(RegistrationWorkerStore, self).__init__(db_conn, hs) self.config = hs.config + self.clock = hs.get_clock() @cached() def get_user_by_id(self, user_id): @@ -87,26 +88,157 @@ class RegistrationWorkerStore(SQLBaseStore): ) @cachedInlineCallbacks() - def get_expiration_ts_for_user(self, user): + def get_expiration_ts_for_user(self, user_id): """Get the expiration timestamp for the account bearing a given user ID. Args: - user (str): The ID of the user. + user_id (str): The ID of the user. Returns: defer.Deferred: None, if the account has no expiration timestamp, - otherwise int representation of the timestamp (as a number of - milliseconds since epoch). + otherwise int representation of the timestamp (as a number of + milliseconds since epoch). """ res = yield self._simple_select_one_onecol( table="account_validity", - keyvalues={"user_id": user.to_string()}, + keyvalues={"user_id": user_id}, retcol="expiration_ts_ms", allow_none=True, - desc="get_expiration_date_for_user", + desc="get_expiration_ts_for_user", + ) + defer.returnValue(res) + + @defer.inlineCallbacks + def renew_account_for_user(self, user_id, new_expiration_ts): + """Updates the account validity table with a new timestamp for a given + user, removes the existing renewal token from this user, and unsets the + flag indicating that an email has been sent for renewing this account. + + Args: + user_id (str): ID of the user whose account validity to renew. + new_expiration_ts: New expiration date, as a timestamp in milliseconds + since epoch. + """ + def renew_account_for_user_txn(txn): + self._simple_update_txn( + txn=txn, + table="account_validity", + keyvalues={"user_id": user_id}, + updatevalues={ + "expiration_ts_ms": new_expiration_ts, + "email_sent": False, + "renewal_token": None, + }, + ) + self._invalidate_cache_and_stream( + txn, self.get_expiration_ts_for_user, (user_id,), + ) + + yield self.runInteraction( + "renew_account_for_user", + renew_account_for_user_txn, + ) + + @defer.inlineCallbacks + def set_renewal_token_for_user(self, user_id, renewal_token): + """Defines a renewal token for a given user. + + Args: + user_id (str): ID of the user to set the renewal token for. + renewal_token (str): Random unique string that will be used to renew the + user's account. + + Raises: + StoreError: The provided token is already set for another user. + """ + yield self._simple_update_one( + table="account_validity", + keyvalues={"user_id": user_id}, + updatevalues={"renewal_token": renewal_token}, + desc="set_renewal_token_for_user", ) + + @defer.inlineCallbacks + def get_user_from_renewal_token(self, renewal_token): + """Get a user ID from a renewal token. + + Args: + renewal_token (str): The renewal token to perform the lookup with. + + Returns: + defer.Deferred[str]: The ID of the user to which the token belongs. + """ + res = yield self._simple_select_one_onecol( + table="account_validity", + keyvalues={"renewal_token": renewal_token}, + retcol="user_id", + desc="get_user_from_renewal_token", + ) + defer.returnValue(res) @defer.inlineCallbacks + def get_renewal_token_for_user(self, user_id): + """Get the renewal token associated with a given user ID. + + Args: + user_id (str): The user ID to lookup a token for. + + Returns: + defer.Deferred[str]: The renewal token associated with this user ID. + """ + res = yield self._simple_select_one_onecol( + table="account_validity", + keyvalues={"user_id": user_id}, + retcol="renewal_token", + desc="get_renewal_token_for_user", + ) + + defer.returnValue(res) + + @defer.inlineCallbacks + def get_users_expiring_soon(self): + """Selects users whose account will expire in the [now, now + renew_at] time + window (see configuration for account_validity for information on what renew_at + refers to). + + Returns: + Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]] + """ + def select_users_txn(txn, now_ms, renew_at): + sql = ( + "SELECT user_id, expiration_ts_ms FROM account_validity" + " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?" + ) + values = [False, now_ms, renew_at] + txn.execute(sql, values) + return self.cursor_to_dict(txn) + + res = yield self.runInteraction( + "get_users_expiring_soon", + select_users_txn, + self.clock.time_msec(), self.config.account_validity.renew_at, + ) + + defer.returnValue(res) + + @defer.inlineCallbacks + def set_renewal_mail_status(self, user_id, email_sent): + """Sets or unsets the flag that indicates whether a renewal email has been sent + to the user (and the user hasn't renewed their account yet). + + Args: + user_id (str): ID of the user to set/unset the flag for. + email_sent (bool): Flag which indicates whether a renewal email has been sent + to this user. + """ + yield self._simple_update_one( + table="account_validity", + keyvalues={"user_id": user_id}, + updatevalues={"email_sent": email_sent}, + desc="set_renewal_mail_status", + ) + + @defer.inlineCallbacks def is_server_admin(self, user): res = yield self._simple_select_one_onecol( table="users", @@ -584,20 +716,22 @@ class RegistrationStore( }, ) - if self._account_validity.enabled: - now_ms = self.clock.time_msec() - expiration_ts = now_ms + self._account_validity.period - self._simple_insert_txn( - txn, - "account_validity", - values={ - "user_id": user_id, - "expiration_ts_ms": expiration_ts, - } - ) except self.database_engine.module.IntegrityError: raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) + if self._account_validity.enabled: + now_ms = self.clock.time_msec() + expiration_ts = now_ms + self._account_validity.period + self._simple_insert_txn( + txn, + "account_validity", + values={ + "user_id": user_id, + "expiration_ts_ms": expiration_ts, + "email_sent": False, + } + ) + if token: # it's possible for this to get a conflict, but only for a single user # since tokens are namespaced based on their user ID |