summary refs log tree commit diff
path: root/synapse/storage/databases/main/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/registration.py')
-rw-r--r--synapse/storage/databases/main/registration.py62
1 files changed, 46 insertions, 16 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 833214b7e0..6e5ee557d2 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -91,13 +91,25 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             id_column=None,
         )
 
-        self._account_validity = hs.config.account_validity
-        if hs.config.run_background_tasks and self._account_validity.enabled:
-            self._clock.call_later(
-                0.0,
-                self._set_expiration_date_when_missing,
+        self._account_validity_enabled = (
+            hs.config.account_validity.account_validity_enabled
+        )
+        self._account_validity_period = None
+        self._account_validity_startup_job_max_delta = None
+        if self._account_validity_enabled:
+            self._account_validity_period = (
+                hs.config.account_validity.account_validity_period
+            )
+            self._account_validity_startup_job_max_delta = (
+                hs.config.account_validity.account_validity_startup_job_max_delta
             )
 
+            if hs.config.run_background_tasks:
+                self._clock.call_later(
+                    0.0,
+                    self._set_expiration_date_when_missing,
+                )
+
         # Create a background job for culling expired 3PID validity tokens
         if hs.config.run_background_tasks:
             self._clock.looping_call(
@@ -194,6 +206,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         expiration_ts: int,
         email_sent: bool,
         renewal_token: Optional[str] = None,
+        token_used_ts: Optional[int] = None,
     ) -> None:
         """Updates the account validity properties of the given account, with the
         given values.
@@ -207,6 +220,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 period.
             renewal_token: Renewal token the user can use to extend the validity
                 of their account. Defaults to no token.
+            token_used_ts: A timestamp of when the current token was used to renew
+                the account.
         """
 
         def set_account_validity_for_user_txn(txn):
@@ -218,6 +233,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                     "expiration_ts_ms": expiration_ts,
                     "email_sent": email_sent,
                     "renewal_token": renewal_token,
+                    "token_used_ts_ms": token_used_ts,
                 },
             )
             self._invalidate_cache_and_stream(
@@ -231,7 +247,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     async def set_renewal_token_for_user(
         self, user_id: str, renewal_token: str
     ) -> None:
-        """Defines a renewal token for a given user.
+        """Defines a renewal token for a given user, and clears the token_used timestamp.
 
         Args:
             user_id: ID of the user to set the renewal token for.
@@ -244,26 +260,40 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         await self.db_pool.simple_update_one(
             table="account_validity",
             keyvalues={"user_id": user_id},
-            updatevalues={"renewal_token": renewal_token},
+            updatevalues={"renewal_token": renewal_token, "token_used_ts_ms": None},
             desc="set_renewal_token_for_user",
         )
 
-    async def get_user_from_renewal_token(self, renewal_token: str) -> str:
-        """Get a user ID from a renewal token.
+    async def get_user_from_renewal_token(
+        self, renewal_token: str
+    ) -> Tuple[str, int, Optional[int]]:
+        """Get a user ID and renewal status from a renewal token.
 
         Args:
             renewal_token: The renewal token to perform the lookup with.
 
         Returns:
-            The ID of the user to which the token belongs.
+            A tuple of containing the following values:
+                * The ID of a user to which the token belongs.
+                * An int representing the user's expiry timestamp as milliseconds since the
+                    epoch, or 0 if the token was invalid.
+                * An optional int representing the timestamp of when the user renewed their
+                    account timestamp as milliseconds since the epoch. None if the account
+                    has not been renewed using the current token yet.
         """
-        return await self.db_pool.simple_select_one_onecol(
+        ret_dict = await self.db_pool.simple_select_one(
             table="account_validity",
             keyvalues={"renewal_token": renewal_token},
-            retcol="user_id",
+            retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
             desc="get_user_from_renewal_token",
         )
 
+        return (
+            ret_dict["user_id"],
+            ret_dict["expiration_ts_ms"],
+            ret_dict["token_used_ts_ms"],
+        )
+
     async def get_renewal_token_for_user(self, user_id: str) -> str:
         """Get the renewal token associated with a given user ID.
 
@@ -302,7 +332,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             "get_users_expiring_soon",
             select_users_txn,
             self._clock.time_msec(),
-            self.config.account_validity.renew_at,
+            self.config.account_validity_renew_at,
         )
 
     async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None:
@@ -964,11 +994,11 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 delta equal to 10% of the validity period.
         """
         now_ms = self._clock.time_msec()
-        expiration_ts = now_ms + self._account_validity.period
+        expiration_ts = now_ms + self._account_validity_period
 
         if use_delta:
             expiration_ts = self.rand.randrange(
-                expiration_ts - self._account_validity.startup_job_max_delta,
+                expiration_ts - self._account_validity_startup_job_max_delta,
                 expiration_ts,
             )
 
@@ -1412,7 +1442,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         except self.database_engine.module.IntegrityError:
             raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
 
-        if self._account_validity.enabled:
+        if self._account_validity_enabled:
             self.set_expiration_date_for_user_txn(txn, user_id)
 
         if create_profile_with_displayname: