diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index e30b86c346..03a06a83d6 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -32,6 +32,7 @@ class RegistrationWorkerStore(SQLBaseStore):
super(RegistrationWorkerStore, self).__init__(db_conn, hs)
self.config = hs.config
+ self.clock = hs.get_clock()
@cached()
def get_user_by_id(self, user_id):
@@ -86,6 +87,162 @@ class RegistrationWorkerStore(SQLBaseStore):
"get_user_by_access_token", self._query_for_auth, token
)
+ @cachedInlineCallbacks()
+ def get_expiration_ts_for_user(self, user_id):
+ """Get the expiration timestamp for the account bearing a given user ID.
+
+ Args:
+ user_id (str): The ID of the user.
+ Returns:
+ defer.Deferred: None, if the account has no expiration timestamp,
+ otherwise int representation of the timestamp (as a number of
+ milliseconds since epoch).
+ """
+ res = yield self._simple_select_one_onecol(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ retcol="expiration_ts_ms",
+ allow_none=True,
+ desc="get_expiration_ts_for_user",
+ )
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def set_account_validity_for_user(self, user_id, expiration_ts, email_sent,
+ renewal_token=None):
+ """Updates the account validity properties of the given account, with the
+ given values.
+
+ Args:
+ user_id (str): ID of the account to update properties for.
+ expiration_ts (int): New expiration date, as a timestamp in milliseconds
+ since epoch.
+ email_sent (bool): True means a renewal email has been sent for this
+ account and there's no need to send another one for the current validity
+ period.
+ renewal_token (str): Renewal token the user can use to extend the validity
+ of their account. Defaults to no token.
+ """
+ def set_account_validity_for_user_txn(txn):
+ self._simple_update_txn(
+ txn=txn,
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ updatevalues={
+ "expiration_ts_ms": expiration_ts,
+ "email_sent": email_sent,
+ "renewal_token": renewal_token,
+ },
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_expiration_ts_for_user, (user_id,),
+ )
+
+ yield self.runInteraction(
+ "set_account_validity_for_user",
+ set_account_validity_for_user_txn,
+ )
+
+ @defer.inlineCallbacks
+ def set_renewal_token_for_user(self, user_id, renewal_token):
+ """Defines a renewal token for a given user.
+
+ Args:
+ user_id (str): ID of the user to set the renewal token for.
+ renewal_token (str): Random unique string that will be used to renew the
+ user's account.
+
+ Raises:
+ StoreError: The provided token is already set for another user.
+ """
+ yield self._simple_update_one(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ updatevalues={"renewal_token": renewal_token},
+ desc="set_renewal_token_for_user",
+ )
+
+ @defer.inlineCallbacks
+ def get_user_from_renewal_token(self, renewal_token):
+ """Get a user ID from a renewal token.
+
+ Args:
+ renewal_token (str): The renewal token to perform the lookup with.
+
+ Returns:
+ defer.Deferred[str]: The ID of the user to which the token belongs.
+ """
+ res = yield self._simple_select_one_onecol(
+ table="account_validity",
+ keyvalues={"renewal_token": renewal_token},
+ retcol="user_id",
+ desc="get_user_from_renewal_token",
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def get_renewal_token_for_user(self, user_id):
+ """Get the renewal token associated with a given user ID.
+
+ Args:
+ user_id (str): The user ID to lookup a token for.
+
+ Returns:
+ defer.Deferred[str]: The renewal token associated with this user ID.
+ """
+ res = yield self._simple_select_one_onecol(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ retcol="renewal_token",
+ desc="get_renewal_token_for_user",
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def get_users_expiring_soon(self):
+ """Selects users whose account will expire in the [now, now + renew_at] time
+ window (see configuration for account_validity for information on what renew_at
+ refers to).
+
+ Returns:
+ Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
+ """
+ def select_users_txn(txn, now_ms, renew_at):
+ sql = (
+ "SELECT user_id, expiration_ts_ms FROM account_validity"
+ " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
+ )
+ values = [False, now_ms, renew_at]
+ txn.execute(sql, values)
+ return self.cursor_to_dict(txn)
+
+ res = yield self.runInteraction(
+ "get_users_expiring_soon",
+ select_users_txn,
+ self.clock.time_msec(), self.config.account_validity.renew_at,
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def set_renewal_mail_status(self, user_id, email_sent):
+ """Sets or unsets the flag that indicates whether a renewal email has been sent
+ to the user (and the user hasn't renewed their account yet).
+
+ Args:
+ user_id (str): ID of the user to set/unset the flag for.
+ email_sent (bool): Flag which indicates whether a renewal email has been sent
+ to this user.
+ """
+ yield self._simple_update_one(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ updatevalues={"email_sent": email_sent},
+ desc="set_renewal_mail_status",
+ )
+
@defer.inlineCallbacks
def is_server_admin(self, user):
res = yield self._simple_select_one_onecol(
@@ -425,6 +582,8 @@ class RegistrationStore(
columns=["creation_ts"],
)
+ self._account_validity = hs.config.account_validity
+
# we no longer use refresh tokens, but it's possible that some people
# might have a background update queued to build this index. Just
# clear the background update.
@@ -561,9 +720,23 @@ class RegistrationStore(
"user_type": user_type,
},
)
+
except self.database_engine.module.IntegrityError:
raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
+ if self._account_validity.enabled:
+ now_ms = self.clock.time_msec()
+ expiration_ts = now_ms + self._account_validity.period
+ self._simple_insert_txn(
+ txn,
+ "account_validity",
+ values={
+ "user_id": user_id,
+ "expiration_ts_ms": expiration_ts,
+ "email_sent": False,
+ }
+ )
+
if token:
# it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID
|