summary refs log tree commit diff
path: root/synapse/storage/databases/main/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/registration.py')
-rw-r--r--synapse/storage/databases/main/registration.py184
1 files changed, 91 insertions, 93 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 16ba545740..a85867936f 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -14,14 +14,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 import re
 from typing import Any, Dict, List, Optional, Tuple
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
+)
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 from synapse.storage.types import Cursor
@@ -48,6 +50,21 @@ class RegistrationWorkerStore(SQLBaseStore):
             database.engine, find_max_generated_user_id_localpart, "user_id_seq",
         )
 
+        self._account_validity = hs.config.account_validity
+        if hs.config.run_background_tasks and self._account_validity.enabled:
+            self._clock.call_later(
+                0.0,
+                run_as_background_process,
+                "account_validity_set_expiration_dates",
+                self._set_expiration_date_when_missing,
+            )
+
+        # Create a background job for culling expired 3PID validity tokens
+        if hs.config.run_background_tasks:
+            self.clock.looping_call(
+                self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
+            )
+
     @cached()
     async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
         return await self.db_pool.simple_select_one(
@@ -778,6 +795,78 @@ class RegistrationWorkerStore(SQLBaseStore):
             "delete_threepid_session", delete_threepid_session_txn
         )
 
+    @wrap_as_background_process("cull_expired_threepid_validation_tokens")
+    async def cull_expired_threepid_validation_tokens(self) -> None:
+        """Remove threepid validation tokens with expiry dates that have passed"""
+
+        def cull_expired_threepid_validation_tokens_txn(txn, ts):
+            sql = """
+            DELETE FROM threepid_validation_token WHERE
+            expires < ?
+            """
+            txn.execute(sql, (ts,))
+
+        await self.db_pool.runInteraction(
+            "cull_expired_threepid_validation_tokens",
+            cull_expired_threepid_validation_tokens_txn,
+            self.clock.time_msec(),
+        )
+
+    async def _set_expiration_date_when_missing(self):
+        """
+        Retrieves the list of registered users that don't have an expiration date, and
+        adds an expiration date for each of them.
+        """
+
+        def select_users_with_no_expiration_date_txn(txn):
+            """Retrieves the list of registered users with no expiration date from the
+            database, filtering out deactivated users.
+            """
+            sql = (
+                "SELECT users.name FROM users"
+                " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+                " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
+            )
+            txn.execute(sql, [])
+
+            res = self.db_pool.cursor_to_dict(txn)
+            if res:
+                for user in res:
+                    self.set_expiration_date_for_user_txn(
+                        txn, user["name"], use_delta=True
+                    )
+
+        await self.db_pool.runInteraction(
+            "get_users_with_no_expiration_date",
+            select_users_with_no_expiration_date_txn,
+        )
+
+    def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+        """Sets an expiration date to the account with the given user ID.
+
+        Args:
+             user_id (str): User ID to set an expiration date for.
+             use_delta (bool): If set to False, the expiration date for the user will be
+                now + validity period. If set to True, this expiration date will be a
+                random value in the [now + period - d ; now + period] range, d being a
+                delta equal to 10% of the validity period.
+        """
+        now_ms = self._clock.time_msec()
+        expiration_ts = now_ms + self._account_validity.period
+
+        if use_delta:
+            expiration_ts = self.rand.randrange(
+                expiration_ts - self._account_validity.startup_job_max_delta,
+                expiration_ts,
+            )
+
+        self.db_pool.simple_upsert_txn(
+            txn,
+            "account_validity",
+            keyvalues={"user_id": user_id},
+            values={"expiration_ts_ms": expiration_ts, "email_sent": False},
+        )
+
 
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
@@ -911,28 +1000,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
 
-        self._account_validity = hs.config.account_validity
         self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
 
-        if self._account_validity.enabled:
-            self._clock.call_later(
-                0.0,
-                run_as_background_process,
-                "account_validity_set_expiration_dates",
-                self._set_expiration_date_when_missing,
-            )
-
-        # Create a background job for culling expired 3PID validity tokens
-        def start_cull():
-            # run as a background process to make sure that the database transactions
-            # have a logcontext to report to
-            return run_as_background_process(
-                "cull_expired_threepid_validation_tokens",
-                self.cull_expired_threepid_validation_tokens,
-            )
-
-        hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
-
     async def add_access_token_to_user(
         self,
         user_id: str,
@@ -1477,22 +1546,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             start_or_continue_validation_session_txn,
         )
 
-    async def cull_expired_threepid_validation_tokens(self) -> None:
-        """Remove threepid validation tokens with expiry dates that have passed"""
-
-        def cull_expired_threepid_validation_tokens_txn(txn, ts):
-            sql = """
-            DELETE FROM threepid_validation_token WHERE
-            expires < ?
-            """
-            txn.execute(sql, (ts,))
-
-        await self.db_pool.runInteraction(
-            "cull_expired_threepid_validation_tokens",
-            cull_expired_threepid_validation_tokens_txn,
-            self.clock.time_msec(),
-        )
-
     async def set_user_deactivated_status(
         self, user_id: str, deactivated: bool
     ) -> None:
@@ -1522,61 +1575,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         )
         txn.call_after(self.is_guest.invalidate, (user_id,))
 
-    async def _set_expiration_date_when_missing(self):
-        """
-        Retrieves the list of registered users that don't have an expiration date, and
-        adds an expiration date for each of them.
-        """
-
-        def select_users_with_no_expiration_date_txn(txn):
-            """Retrieves the list of registered users with no expiration date from the
-            database, filtering out deactivated users.
-            """
-            sql = (
-                "SELECT users.name FROM users"
-                " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
-                " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
-            )
-            txn.execute(sql, [])
-
-            res = self.db_pool.cursor_to_dict(txn)
-            if res:
-                for user in res:
-                    self.set_expiration_date_for_user_txn(
-                        txn, user["name"], use_delta=True
-                    )
-
-        await self.db_pool.runInteraction(
-            "get_users_with_no_expiration_date",
-            select_users_with_no_expiration_date_txn,
-        )
-
-    def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
-        """Sets an expiration date to the account with the given user ID.
-
-        Args:
-             user_id (str): User ID to set an expiration date for.
-             use_delta (bool): If set to False, the expiration date for the user will be
-                now + validity period. If set to True, this expiration date will be a
-                random value in the [now + period - d ; now + period] range, d being a
-                delta equal to 10% of the validity period.
-        """
-        now_ms = self._clock.time_msec()
-        expiration_ts = now_ms + self._account_validity.period
-
-        if use_delta:
-            expiration_ts = self.rand.randrange(
-                expiration_ts - self._account_validity.startup_job_max_delta,
-                expiration_ts,
-            )
-
-        self.db_pool.simple_upsert_txn(
-            txn,
-            "account_validity",
-            keyvalues={"user_id": user_id},
-            values={"expiration_ts_ms": expiration_ts, "email_sent": False},
-        )
-
 
 def find_max_generated_user_id_localpart(cur: Cursor) -> int:
     """