summary refs log tree commit diff
path: root/synapse/push
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/push')
-rw-r--r--synapse/push/pusherpool.py24
1 files changed, 8 insertions, 16 deletions
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 2519ad76db..85621f33ef 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -62,10 +62,6 @@ class PusherPool:
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
 
-        self._account_validity_enabled = (
-            hs.config.account_validity.account_validity_enabled
-        )
-
         # We shard the handling of push notifications by user ID.
         self._pusher_shard_config = hs.config.push.pusher_shard_config
         self._instance_name = hs.get_instance_name()
@@ -89,6 +85,8 @@ class PusherPool:
         # map from user id to app_id:pushkey to pusher
         self.pushers: Dict[str, Dict[str, Pusher]] = {}
 
+        self._account_validity_handler = hs.get_account_validity_handler()
+
     def start(self) -> None:
         """Starts the pushers off in a background process."""
         if not self._should_start_pushers:
@@ -238,12 +236,9 @@ class PusherPool:
 
             for u in users_affected:
                 # Don't push if the user account has expired
-                if self._account_validity_enabled:
-                    expired = await self.store.is_account_expired(
-                        u, self.clock.time_msec()
-                    )
-                    if expired:
-                        continue
+                expired = await self._account_validity_handler.is_user_expired(u)
+                if expired:
+                    continue
 
                 if u in self.pushers:
                     for p in self.pushers[u].values():
@@ -268,12 +263,9 @@ class PusherPool:
 
             for u in users_affected:
                 # Don't push if the user account has expired
-                if self._account_validity_enabled:
-                    expired = await self.store.is_account_expired(
-                        u, self.clock.time_msec()
-                    )
-                    if expired:
-                        continue
+                expired = await self._account_validity_handler.is_user_expired(u)
+                if expired:
+                    continue
 
                 if u in self.pushers:
                     for p in self.pushers[u].values():