diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index eede8ae4d2..a78850259f 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -32,6 +32,7 @@ class RegistrationWorkerStore(SQLBaseStore):
super(RegistrationWorkerStore, self).__init__(db_conn, hs)
self.config = hs.config
+ self.clock = hs.get_clock()
@cached()
def get_user_by_id(self, user_id):
@@ -87,26 +88,157 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@cachedInlineCallbacks()
- def get_expiration_ts_for_user(self, user):
+ def get_expiration_ts_for_user(self, user_id):
"""Get the expiration timestamp for the account bearing a given user ID.
Args:
- user (str): The ID of the user.
+ user_id (str): The ID of the user.
Returns:
defer.Deferred: None, if the account has no expiration timestamp,
- otherwise int representation of the timestamp (as a number of
- milliseconds since epoch).
+ otherwise int representation of the timestamp (as a number of
+ milliseconds since epoch).
"""
res = yield self._simple_select_one_onecol(
table="account_validity",
- keyvalues={"user_id": user.to_string()},
+ keyvalues={"user_id": user_id},
retcol="expiration_ts_ms",
allow_none=True,
- desc="get_expiration_date_for_user",
+ desc="get_expiration_ts_for_user",
+ )
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def renew_account_for_user(self, user_id, new_expiration_ts):
+ """Updates the account validity table with a new timestamp for a given
+ user, removes the existing renewal token from this user, and unsets the
+ flag indicating that an email has been sent for renewing this account.
+
+ Args:
+ user_id (str): ID of the user whose account validity to renew.
+ new_expiration_ts: New expiration date, as a timestamp in milliseconds
+ since epoch.
+ """
+ def renew_account_for_user_txn(txn):
+ self._simple_update_txn(
+ txn=txn,
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ updatevalues={
+ "expiration_ts_ms": new_expiration_ts,
+ "email_sent": False,
+ "renewal_token": None,
+ },
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_expiration_ts_for_user, (user_id,),
+ )
+
+ yield self.runInteraction(
+ "renew_account_for_user",
+ renew_account_for_user_txn,
+ )
+
+ @defer.inlineCallbacks
+ def set_renewal_token_for_user(self, user_id, renewal_token):
+ """Defines a renewal token for a given user.
+
+ Args:
+ user_id (str): ID of the user to set the renewal token for.
+ renewal_token (str): Random unique string that will be used to renew the
+ user's account.
+
+ Raises:
+ StoreError: The provided token is already set for another user.
+ """
+ yield self._simple_update_one(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ updatevalues={"renewal_token": renewal_token},
+ desc="set_renewal_token_for_user",
+ )
+
+ @defer.inlineCallbacks
+ def get_user_from_renewal_token(self, renewal_token):
+ """Get a user ID from a renewal token.
+
+ Args:
+ renewal_token (str): The renewal token to perform the lookup with.
+
+ Returns:
+ defer.Deferred[str]: The ID of the user to which the token belongs.
+ """
+ res = yield self._simple_select_one_onecol(
+ table="account_validity",
+ keyvalues={"renewal_token": renewal_token},
+ retcol="user_id",
+ desc="get_user_from_renewal_token",
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def get_renewal_token_for_user(self, user_id):
+ """Get the renewal token associated with a given user ID.
+
+ Args:
+ user_id (str): The user ID to lookup a token for.
+
+ Returns:
+ defer.Deferred[str]: The renewal token associated with this user ID.
+ """
+ res = yield self._simple_select_one_onecol(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ retcol="renewal_token",
+ desc="get_renewal_token_for_user",
)
+
defer.returnValue(res)
@defer.inlineCallbacks
+ def get_users_expiring_soon(self):
+ """Selects users whose account will expire in the [now, now + renew_at] time
+ window (see configuration for account_validity for information on what renew_at
+ refers to).
+
+ Returns:
+ Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
+ """
+ def select_users_txn(txn, now_ms, renew_at):
+ sql = (
+ "SELECT user_id, expiration_ts_ms FROM account_validity"
+ " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
+ )
+ values = [False, now_ms, renew_at]
+ txn.execute(sql, values)
+ return self.cursor_to_dict(txn)
+
+ res = yield self.runInteraction(
+ "get_users_expiring_soon",
+ select_users_txn,
+ self.clock.time_msec(), self.config.account_validity.renew_at,
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def set_renewal_mail_status(self, user_id, email_sent):
+ """Sets or unsets the flag that indicates whether a renewal email has been sent
+ to the user (and the user hasn't renewed their account yet).
+
+ Args:
+ user_id (str): ID of the user to set/unset the flag for.
+ email_sent (bool): Flag which indicates whether a renewal email has been sent
+ to this user.
+ """
+ yield self._simple_update_one(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ updatevalues={"email_sent": email_sent},
+ desc="set_renewal_mail_status",
+ )
+
+ @defer.inlineCallbacks
def is_server_admin(self, user):
res = yield self._simple_select_one_onecol(
table="users",
@@ -508,22 +640,24 @@ class RegistrationStore(RegistrationWorkerStore,
}
)
- if self._account_validity.enabled:
- now_ms = self.clock.time_msec()
- expiration_ts = now_ms + self._account_validity.period
- self._simple_insert_txn(
- txn,
- "account_validity",
- values={
- "user_id": user_id,
- "expiration_ts_ms": expiration_ts,
- }
- )
except self.database_engine.module.IntegrityError:
raise StoreError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
+ if self._account_validity.enabled:
+ now_ms = self.clock.time_msec()
+ expiration_ts = now_ms + self._account_validity.period
+ self._simple_insert_txn(
+ txn,
+ "account_validity",
+ values={
+ "user_id": user_id,
+ "expiration_ts_ms": expiration_ts,
+ "email_sent": False,
+ }
+ )
+
if token:
# it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID
diff --git a/synapse/storage/schema/delta/54/account_validity.sql b/synapse/storage/schema/delta/54/account_validity.sql
index 57249262d7..2357626000 100644
--- a/synapse/storage/schema/delta/54/account_validity.sql
+++ b/synapse/storage/schema/delta/54/account_validity.sql
@@ -13,8 +13,15 @@
* limitations under the License.
*/
+DROP TABLE IF EXISTS account_validity;
+
-- Track what users are in public rooms.
CREATE TABLE IF NOT EXISTS account_validity (
user_id TEXT PRIMARY KEY,
- expiration_ts_ms BIGINT NOT NULL
+ expiration_ts_ms BIGINT NOT NULL,
+ email_sent BOOLEAN NOT NULL,
+ renewal_token TEXT
);
+
+CREATE INDEX account_validity_email_sent_idx ON account_validity(email_sent, expiration_ts_ms)
+CREATE UNIQUE INDEX account_validity_renewal_string_idx ON account_validity(renewal_token)
|