diff --git a/changelog.d/8353.bugfix b/changelog.d/8353.bugfix
new file mode 100644
index 0000000000..45fc0adb8d
--- /dev/null
+++ b/changelog.d/8353.bugfix
@@ -0,0 +1 @@
+Don't send push notifications to expired user accounts.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 75388643ee..1071a0576e 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -218,11 +218,7 @@ class Auth:
# Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired:
user_id = user.to_string()
- expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
- if (
- expiration_ts is not None
- and self.clock.time_msec() >= expiration_ts
- ):
+ if await self.store.is_account_expired(user_id, self.clock.time_msec()):
raise AuthError(
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
)
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index cc839ffce4..76150e117b 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -60,6 +60,8 @@ class PusherPool:
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
+ self._account_validity = hs.config.account_validity
+
# We shard the handling of push notifications by user ID.
self._pusher_shard_config = hs.config.push.pusher_shard_config
self._instance_name = hs.get_instance_name()
@@ -202,6 +204,14 @@ class PusherPool:
)
for u in users_affected:
+ # Don't push if the user account has expired
+ if self._account_validity.enabled:
+ expired = await self.store.is_account_expired(
+ u, self.clock.time_msec()
+ )
+ if expired:
+ continue
+
if u in self.pushers:
for p in self.pushers[u].values():
p.on_new_notifications(max_stream_id)
@@ -222,6 +232,14 @@ class PusherPool:
)
for u in users_affected:
+ # Don't push if the user account has expired
+ if self._account_validity.enabled:
+ expired = await self.store.is_account_expired(
+ u, self.clock.time_msec()
+ )
+ if expired:
+ continue
+
if u in self.pushers:
for p in self.pushers[u].values():
p.on_new_receipts(min_stream_id, max_stream_id)
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 675e81fe34..33825e8949 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -116,6 +116,20 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_expiration_ts_for_user",
)
+ async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
+ """
+ Returns whether an user account is expired.
+
+ Args:
+ user_id: The user's ID
+ current_ts: The current timestamp
+
+ Returns:
+ Whether the user account has expired
+ """
+ expiration_ts = await self.get_expiration_ts_for_user(user_id)
+ return expiration_ts is not None and current_ts >= expiration_ts
+
async def set_account_validity_for_user(
self,
user_id: str,
|