summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2019-05-17 19:37:31 +0100
committerBrendan Abolivier <babolivier@matrix.org>2019-05-17 19:37:31 +0100
commitad5b4074e1a84b1e50a97144fbf1902ddc89db73 (patch)
tree4e953ee49049dbdafda8e942d4b41e81d5e30ee2 /synapse/storage/_base.py
parentLint (diff)
downloadsynapse-ad5b4074e1a84b1e50a97144fbf1902ddc89db73.tar.xz
Add startup background job for account validity
If account validity is enabled in the server's configuration, this job will run at startup as a background job and will stick an expiration date to any registered account missing one.
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py62
1 files changed, 62 insertions, 0 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 983ce026e1..06d516c9e2 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -1,5 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -227,6 +229,8 @@ class SQLBaseStore(object):
         # A set of tables that are not safe to use native upserts in.
         self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
 
+        self._account_validity = self.hs.config.account_validity
+
         # We add the user_directory_search table to the blacklist on SQLite
         # because the existing search table does not have an index, making it
         # unsafe to use native upserts.
@@ -243,6 +247,14 @@ class SQLBaseStore(object):
                 self._check_safe_to_upsert,
             )
 
+        if self._account_validity.enabled:
+            self._clock.call_later(
+                0.0,
+                run_as_background_process,
+                "account_validity_set_expiration_dates",
+                self._set_expiration_date_when_missing,
+            )
+
     @defer.inlineCallbacks
     def _check_safe_to_upsert(self):
         """
@@ -275,6 +287,56 @@ class SQLBaseStore(object):
                 self._check_safe_to_upsert,
             )
 
+    @defer.inlineCallbacks
+    def _set_expiration_date_when_missing(self):
+        """
+        Retrieves the list of registered users that don't have an expiration date, and
+        adds an expiration date for each of them.
+        """
+
+        def select_users_with_no_expiration_date_txn(txn):
+            """Retrieves the list of registered users with no expiration date from the
+            database.
+            """
+            sql = (
+                "SELECT users.name FROM users"
+                " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+                " WHERE account_validity.user_id is NULL;"
+            )
+            txn.execute(sql, [])
+            return self.cursor_to_dict(txn)
+
+        res = yield self.runInteraction(
+            "get_users_with_no_expiration_date",
+            select_users_with_no_expiration_date_txn,
+        )
+
+        if res:
+            for user in res:
+                self.runInteraction(
+                    "set_expiration_date_for_user_background",
+                    self.set_expiration_date_for_user_txn,
+                    user["name"],
+                )
+
+    def set_expiration_date_for_user_txn(self, txn, user_id):
+        """Sets an expiration date to the account with the given user ID.
+
+        Args:
+             user_id (str): User ID to set an expiration date for.
+        """
+        now_ms = self._clock.time_msec()
+        expiration_ts = now_ms + self._account_validity.period
+        self._simple_insert_txn(
+            txn,
+            "account_validity",
+            values={
+                "user_id": user_id,
+                "expiration_ts_ms": expiration_ts,
+                "email_sent": False,
+            },
+        )
+
     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()