summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/registration.py168
-rw-r--r--synapse/storage/schema/delta/54/account_validity.sql9
2 files changed, 159 insertions, 18 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 643f7a3808..a1085ad80c 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -32,6 +32,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         super(RegistrationWorkerStore, self).__init__(db_conn, hs)
 
         self.config = hs.config
+        self.clock = hs.get_clock()
 
     @cached()
     def get_user_by_id(self, user_id):
@@ -87,26 +88,157 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
 
     @cachedInlineCallbacks()
-    def get_expiration_ts_for_user(self, user):
+    def get_expiration_ts_for_user(self, user_id):
         """Get the expiration timestamp for the account bearing a given user ID.
 
         Args:
-            user (str): The ID of the user.
+            user_id (str): The ID of the user.
         Returns:
             defer.Deferred: None, if the account has no expiration timestamp,
-            otherwise int representation of the timestamp (as a number of
-            milliseconds since epoch).
+                otherwise int representation of the timestamp (as a number of
+                milliseconds since epoch).
         """
         res = yield self._simple_select_one_onecol(
             table="account_validity",
-            keyvalues={"user_id": user.to_string()},
+            keyvalues={"user_id": user_id},
             retcol="expiration_ts_ms",
             allow_none=True,
-            desc="get_expiration_date_for_user",
+            desc="get_expiration_ts_for_user",
+        )
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def renew_account_for_user(self, user_id, new_expiration_ts):
+        """Updates the account validity table with a new timestamp for a given
+        user, removes the existing renewal token from this user, and unsets the
+        flag indicating that an email has been sent for renewing this account.
+
+        Args:
+            user_id (str): ID of the user whose account validity to renew.
+            new_expiration_ts: New expiration date, as a timestamp in milliseconds
+                since epoch.
+        """
+        def renew_account_for_user_txn(txn):
+            self._simple_update_txn(
+                txn=txn,
+                table="account_validity",
+                keyvalues={"user_id": user_id},
+                updatevalues={
+                    "expiration_ts_ms": new_expiration_ts,
+                    "email_sent": False,
+                    "renewal_token": None,
+                },
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_expiration_ts_for_user, (user_id,),
+            )
+
+        yield self.runInteraction(
+            "renew_account_for_user",
+            renew_account_for_user_txn,
+        )
+
+    @defer.inlineCallbacks
+    def set_renewal_token_for_user(self, user_id, renewal_token):
+        """Defines a renewal token for a given user.
+
+        Args:
+            user_id (str): ID of the user to set the renewal token for.
+            renewal_token (str): Random unique string that will be used to renew the
+                user's account.
+
+        Raises:
+            StoreError: The provided token is already set for another user.
+        """
+        yield self._simple_update_one(
+            table="account_validity",
+            keyvalues={"user_id": user_id},
+            updatevalues={"renewal_token": renewal_token},
+            desc="set_renewal_token_for_user",
         )
+
+    @defer.inlineCallbacks
+    def get_user_from_renewal_token(self, renewal_token):
+        """Get a user ID from a renewal token.
+
+        Args:
+            renewal_token (str): The renewal token to perform the lookup with.
+
+        Returns:
+            defer.Deferred[str]: The ID of the user to which the token belongs.
+        """
+        res = yield self._simple_select_one_onecol(
+            table="account_validity",
+            keyvalues={"renewal_token": renewal_token},
+            retcol="user_id",
+            desc="get_user_from_renewal_token",
+        )
+
         defer.returnValue(res)
 
     @defer.inlineCallbacks
+    def get_renewal_token_for_user(self, user_id):
+        """Get the renewal token associated with a given user ID.
+
+        Args:
+            user_id (str): The user ID to lookup a token for.
+
+        Returns:
+            defer.Deferred[str]: The renewal token associated with this user ID.
+        """
+        res = yield self._simple_select_one_onecol(
+            table="account_validity",
+            keyvalues={"user_id": user_id},
+            retcol="renewal_token",
+            desc="get_renewal_token_for_user",
+        )
+
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def get_users_expiring_soon(self):
+        """Selects users whose account will expire in the [now, now + renew_at] time
+        window (see configuration for account_validity for information on what renew_at
+        refers to).
+
+        Returns:
+            Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
+        """
+        def select_users_txn(txn, now_ms, renew_at):
+            sql = (
+                "SELECT user_id, expiration_ts_ms FROM account_validity"
+                " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
+            )
+            values = [False, now_ms, renew_at]
+            txn.execute(sql, values)
+            return self.cursor_to_dict(txn)
+
+        res = yield self.runInteraction(
+            "get_users_expiring_soon",
+            select_users_txn,
+            self.clock.time_msec(), self.config.account_validity.renew_at,
+        )
+
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def set_renewal_mail_status(self, user_id, email_sent):
+        """Sets or unsets the flag that indicates whether a renewal email has been sent
+        to the user (and the user hasn't renewed their account yet).
+
+        Args:
+            user_id (str): ID of the user to set/unset the flag for.
+            email_sent (bool): Flag which indicates whether a renewal email has been sent
+                to this user.
+        """
+        yield self._simple_update_one(
+            table="account_validity",
+            keyvalues={"user_id": user_id},
+            updatevalues={"email_sent": email_sent},
+            desc="set_renewal_mail_status",
+        )
+
+    @defer.inlineCallbacks
     def is_server_admin(self, user):
         res = yield self._simple_select_one_onecol(
             table="users",
@@ -584,20 +716,22 @@ class RegistrationStore(
                     },
                 )
 
-            if self._account_validity.enabled:
-                now_ms = self.clock.time_msec()
-                expiration_ts = now_ms + self._account_validity.period
-                self._simple_insert_txn(
-                    txn,
-                    "account_validity",
-                    values={
-                        "user_id": user_id,
-                        "expiration_ts_ms": expiration_ts,
-                    }
-                )
         except self.database_engine.module.IntegrityError:
             raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
 
+        if self._account_validity.enabled:
+            now_ms = self.clock.time_msec()
+            expiration_ts = now_ms + self._account_validity.period
+            self._simple_insert_txn(
+                txn,
+                "account_validity",
+                values={
+                    "user_id": user_id,
+                    "expiration_ts_ms": expiration_ts,
+                    "email_sent": False,
+                }
+            )
+
         if token:
             # it's possible for this to get a conflict, but only for a single user
             # since tokens are namespaced based on their user ID
diff --git a/synapse/storage/schema/delta/54/account_validity.sql b/synapse/storage/schema/delta/54/account_validity.sql
index 57249262d7..2357626000 100644
--- a/synapse/storage/schema/delta/54/account_validity.sql
+++ b/synapse/storage/schema/delta/54/account_validity.sql
@@ -13,8 +13,15 @@
  * limitations under the License.
  */
 
+DROP TABLE IF EXISTS account_validity;
+
 -- Track what users are in public rooms.
 CREATE TABLE IF NOT EXISTS account_validity (
     user_id TEXT PRIMARY KEY,
-    expiration_ts_ms BIGINT NOT NULL
+    expiration_ts_ms BIGINT NOT NULL,
+    email_sent BOOLEAN NOT NULL,
+    renewal_token TEXT
 );
+
+CREATE INDEX account_validity_email_sent_idx ON account_validity(email_sent, expiration_ts_ms)
+CREATE UNIQUE INDEX account_validity_renewal_string_idx ON account_validity(renewal_token)