summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
authorBrendan Abolivier <contact@brendanabolivier.com>2019-04-17 14:57:39 +0100
committerGitHub <noreply@github.com>2019-04-17 14:57:39 +0100
commit91934025b93bec62c8f5bf12f1975b0c6bffde93 (patch)
tree0bd1029b31353d10868fa3cd3b3b22b70568e472 /synapse/storage/registration.py
parentMerge pull request #5071 from matrix-org/babolivier/3pid-check (diff)
parentSend out emails with links to extend an account's validity period (diff)
downloadsynapse-91934025b93bec62c8f5bf12f1975b0c6bffde93.tar.xz
Merge pull request #5047 from matrix-org/babolivier/account_expiration
Send out emails with links to extend an account's validity period
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r--synapse/storage/registration.py168
1 files changed, 151 insertions, 17 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 643f7a3808..a1085ad80c 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -32,6 +32,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         super(RegistrationWorkerStore, self).__init__(db_conn, hs)
 
         self.config = hs.config
+        self.clock = hs.get_clock()
 
     @cached()
     def get_user_by_id(self, user_id):
@@ -87,26 +88,157 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
 
     @cachedInlineCallbacks()
-    def get_expiration_ts_for_user(self, user):
+    def get_expiration_ts_for_user(self, user_id):
         """Get the expiration timestamp for the account bearing a given user ID.
 
         Args:
-            user (str): The ID of the user.
+            user_id (str): The ID of the user.
         Returns:
             defer.Deferred: None, if the account has no expiration timestamp,
-            otherwise int representation of the timestamp (as a number of
-            milliseconds since epoch).
+                otherwise int representation of the timestamp (as a number of
+                milliseconds since epoch).
         """
         res = yield self._simple_select_one_onecol(
             table="account_validity",
-            keyvalues={"user_id": user.to_string()},
+            keyvalues={"user_id": user_id},
             retcol="expiration_ts_ms",
             allow_none=True,
-            desc="get_expiration_date_for_user",
+            desc="get_expiration_ts_for_user",
+        )
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def renew_account_for_user(self, user_id, new_expiration_ts):
+        """Updates the account validity table with a new timestamp for a given
+        user, removes the existing renewal token from this user, and unsets the
+        flag indicating that an email has been sent for renewing this account.
+
+        Args:
+            user_id (str): ID of the user whose account validity to renew.
+            new_expiration_ts: New expiration date, as a timestamp in milliseconds
+                since epoch.
+        """
+        def renew_account_for_user_txn(txn):
+            self._simple_update_txn(
+                txn=txn,
+                table="account_validity",
+                keyvalues={"user_id": user_id},
+                updatevalues={
+                    "expiration_ts_ms": new_expiration_ts,
+                    "email_sent": False,
+                    "renewal_token": None,
+                },
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_expiration_ts_for_user, (user_id,),
+            )
+
+        yield self.runInteraction(
+            "renew_account_for_user",
+            renew_account_for_user_txn,
+        )
+
+    @defer.inlineCallbacks
+    def set_renewal_token_for_user(self, user_id, renewal_token):
+        """Defines a renewal token for a given user.
+
+        Args:
+            user_id (str): ID of the user to set the renewal token for.
+            renewal_token (str): Random unique string that will be used to renew the
+                user's account.
+
+        Raises:
+            StoreError: The provided token is already set for another user.
+        """
+        yield self._simple_update_one(
+            table="account_validity",
+            keyvalues={"user_id": user_id},
+            updatevalues={"renewal_token": renewal_token},
+            desc="set_renewal_token_for_user",
         )
+
+    @defer.inlineCallbacks
+    def get_user_from_renewal_token(self, renewal_token):
+        """Get a user ID from a renewal token.
+
+        Args:
+            renewal_token (str): The renewal token to perform the lookup with.
+
+        Returns:
+            defer.Deferred[str]: The ID of the user to which the token belongs.
+        """
+        res = yield self._simple_select_one_onecol(
+            table="account_validity",
+            keyvalues={"renewal_token": renewal_token},
+            retcol="user_id",
+            desc="get_user_from_renewal_token",
+        )
+
         defer.returnValue(res)
 
     @defer.inlineCallbacks
+    def get_renewal_token_for_user(self, user_id):
+        """Get the renewal token associated with a given user ID.
+
+        Args:
+            user_id (str): The user ID to lookup a token for.
+
+        Returns:
+            defer.Deferred[str]: The renewal token associated with this user ID.
+        """
+        res = yield self._simple_select_one_onecol(
+            table="account_validity",
+            keyvalues={"user_id": user_id},
+            retcol="renewal_token",
+            desc="get_renewal_token_for_user",
+        )
+
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def get_users_expiring_soon(self):
+        """Selects users whose account will expire in the [now, now + renew_at] time
+        window (see configuration for account_validity for information on what renew_at
+        refers to).
+
+        Returns:
+            Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
+        """
+        def select_users_txn(txn, now_ms, renew_at):
+            sql = (
+                "SELECT user_id, expiration_ts_ms FROM account_validity"
+                " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
+            )
+            values = [False, now_ms, renew_at]
+            txn.execute(sql, values)
+            return self.cursor_to_dict(txn)
+
+        res = yield self.runInteraction(
+            "get_users_expiring_soon",
+            select_users_txn,
+            self.clock.time_msec(), self.config.account_validity.renew_at,
+        )
+
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def set_renewal_mail_status(self, user_id, email_sent):
+        """Sets or unsets the flag that indicates whether a renewal email has been sent
+        to the user (and the user hasn't renewed their account yet).
+
+        Args:
+            user_id (str): ID of the user to set/unset the flag for.
+            email_sent (bool): Flag which indicates whether a renewal email has been sent
+                to this user.
+        """
+        yield self._simple_update_one(
+            table="account_validity",
+            keyvalues={"user_id": user_id},
+            updatevalues={"email_sent": email_sent},
+            desc="set_renewal_mail_status",
+        )
+
+    @defer.inlineCallbacks
     def is_server_admin(self, user):
         res = yield self._simple_select_one_onecol(
             table="users",
@@ -584,20 +716,22 @@ class RegistrationStore(
                     },
                 )
 
-            if self._account_validity.enabled:
-                now_ms = self.clock.time_msec()
-                expiration_ts = now_ms + self._account_validity.period
-                self._simple_insert_txn(
-                    txn,
-                    "account_validity",
-                    values={
-                        "user_id": user_id,
-                        "expiration_ts_ms": expiration_ts,
-                    }
-                )
         except self.database_engine.module.IntegrityError:
             raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
 
+        if self._account_validity.enabled:
+            now_ms = self.clock.time_msec()
+            expiration_ts = now_ms + self._account_validity.period
+            self._simple_insert_txn(
+                txn,
+                "account_validity",
+                values={
+                    "user_id": user_id,
+                    "expiration_ts_ms": expiration_ts,
+                    "email_sent": False,
+                }
+            )
+
         if token:
             # it's possible for this to get a conflict, but only for a single user
             # since tokens are namespaced based on their user ID