summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r--synapse/storage/registration.py173
1 files changed, 173 insertions, 0 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index e30b86c346..03a06a83d6 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -32,6 +32,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         super(RegistrationWorkerStore, self).__init__(db_conn, hs)
 
         self.config = hs.config
+        self.clock = hs.get_clock()
 
     @cached()
     def get_user_by_id(self, user_id):
@@ -86,6 +87,162 @@ class RegistrationWorkerStore(SQLBaseStore):
             "get_user_by_access_token", self._query_for_auth, token
         )
 
+    @cachedInlineCallbacks()
+    def get_expiration_ts_for_user(self, user_id):
+        """Get the expiration timestamp for the account bearing a given user ID.
+
+        Args:
+            user_id (str): The ID of the user.
+        Returns:
+            defer.Deferred: None, if the account has no expiration timestamp,
+                otherwise int representation of the timestamp (as a number of
+                milliseconds since epoch).
+        """
+        res = yield self._simple_select_one_onecol(
+            table="account_validity",
+            keyvalues={"user_id": user_id},
+            retcol="expiration_ts_ms",
+            allow_none=True,
+            desc="get_expiration_ts_for_user",
+        )
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def set_account_validity_for_user(self, user_id, expiration_ts, email_sent,
+                                      renewal_token=None):
+        """Updates the account validity properties of the given account, with the
+        given values.
+
+        Args:
+            user_id (str): ID of the account to update properties for.
+            expiration_ts (int): New expiration date, as a timestamp in milliseconds
+                since epoch.
+            email_sent (bool): True means a renewal email has been sent for this
+                account and there's no need to send another one for the current validity
+                period.
+            renewal_token (str): Renewal token the user can use to extend the validity
+                of their account. Defaults to no token.
+        """
+        def set_account_validity_for_user_txn(txn):
+            self._simple_update_txn(
+                txn=txn,
+                table="account_validity",
+                keyvalues={"user_id": user_id},
+                updatevalues={
+                    "expiration_ts_ms": expiration_ts,
+                    "email_sent": email_sent,
+                    "renewal_token": renewal_token,
+                },
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_expiration_ts_for_user, (user_id,),
+            )
+
+        yield self.runInteraction(
+            "set_account_validity_for_user",
+            set_account_validity_for_user_txn,
+        )
+
+    @defer.inlineCallbacks
+    def set_renewal_token_for_user(self, user_id, renewal_token):
+        """Defines a renewal token for a given user.
+
+        Args:
+            user_id (str): ID of the user to set the renewal token for.
+            renewal_token (str): Random unique string that will be used to renew the
+                user's account.
+
+        Raises:
+            StoreError: The provided token is already set for another user.
+        """
+        yield self._simple_update_one(
+            table="account_validity",
+            keyvalues={"user_id": user_id},
+            updatevalues={"renewal_token": renewal_token},
+            desc="set_renewal_token_for_user",
+        )
+
+    @defer.inlineCallbacks
+    def get_user_from_renewal_token(self, renewal_token):
+        """Get a user ID from a renewal token.
+
+        Args:
+            renewal_token (str): The renewal token to perform the lookup with.
+
+        Returns:
+            defer.Deferred[str]: The ID of the user to which the token belongs.
+        """
+        res = yield self._simple_select_one_onecol(
+            table="account_validity",
+            keyvalues={"renewal_token": renewal_token},
+            retcol="user_id",
+            desc="get_user_from_renewal_token",
+        )
+
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def get_renewal_token_for_user(self, user_id):
+        """Get the renewal token associated with a given user ID.
+
+        Args:
+            user_id (str): The user ID to lookup a token for.
+
+        Returns:
+            defer.Deferred[str]: The renewal token associated with this user ID.
+        """
+        res = yield self._simple_select_one_onecol(
+            table="account_validity",
+            keyvalues={"user_id": user_id},
+            retcol="renewal_token",
+            desc="get_renewal_token_for_user",
+        )
+
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def get_users_expiring_soon(self):
+        """Selects users whose account will expire in the [now, now + renew_at] time
+        window (see configuration for account_validity for information on what renew_at
+        refers to).
+
+        Returns:
+            Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
+        """
+        def select_users_txn(txn, now_ms, renew_at):
+            sql = (
+                "SELECT user_id, expiration_ts_ms FROM account_validity"
+                " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
+            )
+            values = [False, now_ms, renew_at]
+            txn.execute(sql, values)
+            return self.cursor_to_dict(txn)
+
+        res = yield self.runInteraction(
+            "get_users_expiring_soon",
+            select_users_txn,
+            self.clock.time_msec(), self.config.account_validity.renew_at,
+        )
+
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def set_renewal_mail_status(self, user_id, email_sent):
+        """Sets or unsets the flag that indicates whether a renewal email has been sent
+        to the user (and the user hasn't renewed their account yet).
+
+        Args:
+            user_id (str): ID of the user to set/unset the flag for.
+            email_sent (bool): Flag which indicates whether a renewal email has been sent
+                to this user.
+        """
+        yield self._simple_update_one(
+            table="account_validity",
+            keyvalues={"user_id": user_id},
+            updatevalues={"email_sent": email_sent},
+            desc="set_renewal_mail_status",
+        )
+
     @defer.inlineCallbacks
     def is_server_admin(self, user):
         res = yield self._simple_select_one_onecol(
@@ -425,6 +582,8 @@ class RegistrationStore(
             columns=["creation_ts"],
         )
 
+        self._account_validity = hs.config.account_validity
+
         # we no longer use refresh tokens, but it's possible that some people
         # might have a background update queued to build this index. Just
         # clear the background update.
@@ -561,9 +720,23 @@ class RegistrationStore(
                         "user_type": user_type,
                     },
                 )
+
         except self.database_engine.module.IntegrityError:
             raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
 
+        if self._account_validity.enabled:
+            now_ms = self.clock.time_msec()
+            expiration_ts = now_ms + self._account_validity.period
+            self._simple_insert_txn(
+                txn,
+                "account_validity",
+                values={
+                    "user_id": user_id,
+                    "expiration_ts_ms": expiration_ts,
+                    "email_sent": False,
+                }
+            )
+
         if token:
             # it's possible for this to get a conflict, but only for a single user
             # since tokens are namespaced based on their user ID