summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/auth.py4
-rw-r--r--synapse/config/account_validity.py31
-rw-r--r--synapse/handlers/account_validity.py19
-rw-r--r--synapse/handlers/deactivate_account.py2
-rw-r--r--synapse/push/pusherpool.py6
-rw-r--r--synapse/rest/client/v2_alpha/account_validity.py4
-rw-r--r--synapse/storage/databases/main/registration.py17
7 files changed, 50 insertions, 33 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index bfcaf68b2a..5d45ffcc0b 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -76,7 +76,7 @@ class Auth:
 
         self._auth_blocking = AuthBlocking(self.hs)
 
-        self._account_validity = hs.config.account_validity
+        self._account_validity_enabled = hs.config.account_validity_enabled
         self._track_appservice_user_ips = hs.config.track_appservice_user_ips
         self._macaroon_secret_key = hs.config.macaroon_secret_key
 
@@ -219,7 +219,7 @@ class Auth:
             shadow_banned = user_info.shadow_banned
 
             # Deny the request if the user account has expired.
-            if self._account_validity.enabled and not allow_expired:
+            if self._account_validity_enabled and not allow_expired:
                 if await self.store.is_account_expired(
                     user_info.user_id, self.clock.time_msec()
                 ):
diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py
index 66545d717c..34df7f87f4 100644
--- a/synapse/config/account_validity.py
+++ b/synapse/config/account_validity.py
@@ -23,26 +23,31 @@ class AccountValidityConfig(Config):
     section = "accountvalidity"
 
     def read_config(self, config, **kwargs):
-        self.enabled = config.get("enabled", False)
-        self.renew_by_email_enabled = "renew_at" in config
+        account_validity_config = config.get("account_validity", {})
+        self.account_validity_enabled = account_validity_config.get("enabled", False)
+        self.account_validity_renew_by_email_enabled = (
+            "renew_at" in account_validity_config
+        )
 
-        if self.enabled:
+        if self.account_validity_enabled:
             if "period" in config:
-                self.period = self.parse_duration(config["period"])
+                self.account_validity_period = self.parse_duration(config["period"])
             else:
                 raise ConfigError("'period' is required when using account validity")
 
             if "renew_at" in config:
-                self.renew_at = self.parse_duration(config["renew_at"])
+                self.account_validity_renew_at = self.parse_duration(config["renew_at"])
 
             if "renew_email_subject" in config:
-                self.renew_email_subject = config["renew_email_subject"]
+                self.account_validity_renew_email_subject = config[
+                    "renew_email_subject"
+                ]
             else:
-                self.renew_email_subject = "Renew your %(app)s account"
+                self.account_validity_renew_email_subject = "Renew your %(app)s account"
 
-            self.startup_job_max_delta = self.period * 10.0 / 100.0
+            self.account_validity_startup_job_max_delta = self.period * 10.0 / 100.0
 
-        if self.renew_by_email_enabled:
+        if self.account_validity_renew_by_email_enabled:
             if not self.public_baseurl:
                 raise ConfigError("Can't send renewal emails without 'public_baseurl'")
 
@@ -54,22 +59,22 @@ class AccountValidityConfig(Config):
         if "account_renewed_html_path" in config:
             file_path = os.path.join(template_dir, config["account_renewed_html_path"])
 
-            self.account_renewed_html_content = self.read_file(
+            self.account_validity_account_renewed_html_content = self.read_file(
                 file_path, "account_validity.account_renewed_html_path"
             )
         else:
-            self.account_renewed_html_content = (
+            self.account_validity_account_renewed_html_content = (
                 "<html><body>Your account has been successfully renewed.</body><html>"
             )
 
         if "invalid_token_html_path" in config:
             file_path = os.path.join(template_dir, config["invalid_token_html_path"])
 
-            self.invalid_token_html_content = self.read_file(
+            self.account_validity_invalid_token_html_content = self.read_file(
                 file_path, "account_validity.invalid_token_html_path"
             )
         else:
-            self.invalid_token_html_content = (
+            self.account_validity_invalid_token_html_content = (
                 "<html><body>Invalid renewal token.</body><html>"
             )
 
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 664d09da1c..65eb0735fd 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -40,11 +40,18 @@ class AccountValidityHandler:
         self.sendmail = self.hs.get_sendmail()
         self.clock = self.hs.get_clock()
 
-        self._account_validity = self.hs.config.account_validity
+        self._account_validity_period = self.hs.config.account_validity_period
+        self._account_validity_enabled = self.hs.config.account_validity_enabled
+        self._account_validity_renew_email_subject = (
+            self.hs.config.account_validity_renew_email_subject
+        )
+        self._account_validity_renew_by_email_enabled = (
+            self.hs.config.account_validity_renew_by_email_enabled
+        )
 
         if (
-            self._account_validity.enabled
-            and self._account_validity.renew_by_email_enabled
+            self._account_validity_enabled
+            and self._account_validity_renew_by_email_enabled
         ):
             # Don't do email-specific configuration if renewal by email is disabled.
             self._template_html = self.config.account_validity_template_html
@@ -53,14 +60,14 @@ class AccountValidityHandler:
             try:
                 app_name = self.hs.config.email_app_name
 
-                self._subject = self._account_validity.renew_email_subject % {
+                self._subject = self._account_validity_renew_email_subject % {
                     "app": app_name
                 }
 
                 self._from_string = self.hs.config.email_notif_from % {"app": app_name}
             except Exception:
                 # If substitution failed, fall back to the bare strings.
-                self._subject = self._account_validity.renew_email_subject
+                self._subject = self._account_validity_renew_email_subject
                 self._from_string = self.hs.config.email_notif_from
 
             self._raw_from = email.utils.parseaddr(self._from_string)[1]
@@ -258,7 +265,7 @@ class AccountValidityHandler:
             milliseconds since epoch.
         """
         if expiration_ts is None:
-            expiration_ts = self.clock.time_msec() + self._account_validity.period
+            expiration_ts = self.clock.time_msec() + self._account_validity_period
 
         await self.store.set_account_validity_for_user(
             user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index e808142365..b2f965093f 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -49,7 +49,7 @@ class DeactivateAccountHandler(BaseHandler):
         if hs.config.run_background_tasks:
             hs.get_reactor().callWhenRunning(self._start_user_parting)
 
-        self._account_validity_enabled = hs.config.account_validity.enabled
+        self._account_validity_enabled = hs.config.account_validity_enabled
 
     async def deactivate_account(
         self, user_id: str, erase_data: bool, id_server: Optional[str] = None
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 9fcc0b8a64..748b1407c6 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -62,7 +62,7 @@ class PusherPool:
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
 
-        self._account_validity = hs.config.account_validity
+        self._account_validity_enabled = hs.config.account_validity_enabled
 
         # We shard the handling of push notifications by user ID.
         self._pusher_shard_config = hs.config.push.pusher_shard_config
@@ -223,7 +223,7 @@ class PusherPool:
 
             for u in users_affected:
                 # Don't push if the user account has expired
-                if self._account_validity.enabled:
+                if self._account_validity_enabled:
                     expired = await self.store.is_account_expired(
                         u, self.clock.time_msec()
                     )
@@ -251,7 +251,7 @@ class PusherPool:
 
             for u in users_affected:
                 # Don't push if the user account has expired
-                if self._account_validity.enabled:
+                if self._account_validity_enabled:
                     expired = await self.store.is_account_expired(
                         u, self.clock.time_msec()
                     )
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index bd7f9ae203..c9761f05ae 100644
--- a/synapse/rest/client/v2_alpha/account_validity.py
+++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -37,8 +37,8 @@ class AccountValidityRenewServlet(RestServlet):
         self.hs = hs
         self.account_activity_handler = hs.get_account_validity_handler()
         self.auth = hs.get_auth()
-        self.success_html = hs.config.account_validity.account_renewed_html_content
-        self.failure_html = hs.config.account_validity.invalid_token_html_content
+        self.success_html = hs.config.account_validity_account_renewed_html_content
+        self.failure_html = hs.config.account_validity_invalid_token_html_content
 
     async def on_GET(self, request):
         if b"token" not in request.args:
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index ff96c34c2e..0f1e88031d 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -82,8 +82,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             database.engine, find_max_generated_user_id_localpart, "user_id_seq",
         )
 
-        self._account_validity = hs.config.account_validity
-        if hs.config.run_background_tasks and self._account_validity.enabled:
+        self._account_validity_enabled = hs.config.account_validity_enabled
+        self._account_validity_period = hs.config.account_validity_period
+        self._account_validity_renew_at = hs.config.account_validity_renew_at
+        self._account_validity_startup_job_max_delta = (
+            hs.config.account_validity_startup_job_max_delta
+        )
+        if hs.config.run_background_tasks and self._account_validity_enabled:
             self._clock.call_later(
                 0.0, self._set_expiration_date_when_missing,
             )
@@ -291,7 +296,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             "get_users_expiring_soon",
             select_users_txn,
             self._clock.time_msec(),
-            self.config.account_validity.renew_at,
+            self.config.account_validity_renew_at,
         )
 
     async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None:
@@ -902,11 +907,11 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 delta equal to 10% of the validity period.
         """
         now_ms = self._clock.time_msec()
-        expiration_ts = now_ms + self._account_validity.period
+        expiration_ts = now_ms + self._account_validity_period
 
         if use_delta:
             expiration_ts = self.rand.randrange(
-                expiration_ts - self._account_validity.startup_job_max_delta,
+                expiration_ts - self._account_validity_startup_job_max_delta,
                 expiration_ts,
             )
 
@@ -1306,7 +1311,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         except self.database_engine.module.IntegrityError:
             raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
 
-        if self._account_validity.enabled:
+        if self._account_validity_enabled:
             self.set_expiration_date_for_user_txn(txn, user_id)
 
         if create_profile_with_displayname: