summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/config/registration.py11
-rw-r--r--synapse/storage/_base.py23
-rw-r--r--tests/rest/client/v2_alpha/test_register.py21
3 files changed, 53 insertions, 2 deletions
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 693288f938..b4fd4af368 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -39,6 +39,10 @@ class AccountValidityConfig(Config):
             else:
                 self.renew_email_subject = "Renew your %(app)s account"
 
+            self.startup_job_max_delta = self.parse_duration(
+                config.get("startup_job_max_delta", 0),
+            )
+
         if self.renew_by_email_enabled and "public_baseurl" not in synapse_config:
             raise ConfigError("Can't send renewal emails without 'public_baseurl'")
 
@@ -131,11 +135,18 @@ class RegistrationConfig(Config):
         # after that the validity period changes and Synapse is restarted, the users'
         # expiration dates won't be updated unless their account is manually renewed.
         #
+        # If set, the ``startup_job_max_delta`` optional setting will make the startup job
+        # described above set a random expiration date between t + period  and
+        # t + period + startup_job_max_delta, t being the date and time at which the job
+        # sets the expiration date for a given user. This is useful for server admins that
+        # want to avoid Synapse sending a lot of renewal emails at once.
+        #
         #account_validity:
         #  enabled: True
         #  period: 6w
         #  renew_at: 1w
         #  renew_email_subject: "Renew your %%(app)s account"
+        #  startup_job_max_delta: 2d
 
         # The user must provide all of the below types of 3PID when registering.
         #
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index fa6839ceca..40802fd3dc 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 import itertools
 import logging
+import random
 import sys
 import threading
 import time
@@ -247,6 +248,8 @@ class SQLBaseStore(object):
                 self._check_safe_to_upsert,
             )
 
+        self.rand = random.SystemRandom()
+
         if self._account_validity.enabled:
             self._clock.call_later(
                 0.0,
@@ -308,21 +311,37 @@ class SQLBaseStore(object):
             res = self.cursor_to_dict(txn)
             if res:
                 for user in res:
-                    self.set_expiration_date_for_user_txn(txn, user["name"])
+                    self.set_expiration_date_for_user_txn(
+                        txn,
+                        user["name"],
+                        use_delta=True,
+                    )
 
         yield self.runInteraction(
             "get_users_with_no_expiration_date",
             select_users_with_no_expiration_date_txn,
         )
 
-    def set_expiration_date_for_user_txn(self, txn, user_id):
+    def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
         """Sets an expiration date to the account with the given user ID.
 
         Args:
              user_id (str): User ID to set an expiration date for.
+             use_delta (bool): If set to False, the expiration date for the user will be
+                now + validity period. If set to True, this expiration date will be a
+                random value in the [now + period; now + period + max_delta] range,
+                max_delta being the configured value for the size of the range, unless
+                delta is 0, in which case it sets it to now + period.
         """
         now_ms = self._clock.time_msec()
         expiration_ts = now_ms + self._account_validity.period
+
+        if use_delta and self._account_validity.startup_job_max_delta:
+            expiration_ts = self.rand.randrange(
+                expiration_ts,
+                expiration_ts + self._account_validity.startup_job_max_delta,
+            )
+
         self._simple_insert_txn(
             txn,
             "account_validity",
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index d4a1d4d50c..7603440fd8 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -436,6 +436,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
 
     def make_homeserver(self, reactor, clock):
         self.validity_period = 10
+        self.max_delta = 10
 
         config = self.default_config()
 
@@ -459,8 +460,28 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
         """
         user_id = self.register_user("kermit", "user")
 
+        self.hs.config.account_validity.startup_job_max_delta = 0
+
         now_ms = self.hs.clock.time_msec()
         self.get_success(self.store._set_expiration_date_when_missing())
 
         res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
         self.assertEqual(res, now_ms + self.validity_period)
+
+    def test_background_job_with_max_delta(self):
+        """
+        Tests the same thing as test_background_job, except that it sets the
+        startup_job_max_delta parameter and checks that the expiration date is within the
+        allowed range.
+        """
+        user_id = self.register_user("kermit_delta", "user")
+
+        self.hs.config.account_validity.startup_job_max_delta = self.max_delta
+
+        now_ms = self.hs.clock.time_msec()
+        self.get_success(self.store._set_expiration_date_when_missing())
+
+        res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
+
+        self.assertLessEqual(res, now_ms + self.validity_period + self.delta)
+        self.assertGreaterEqual(res, now_ms + self.validity_period)