diff --git a/changelog.d/58.misc b/changelog.d/58.misc
new file mode 100644
index 0000000000..64098a68a4
--- /dev/null
+++ b/changelog.d/58.misc
@@ -0,0 +1 @@
+Don't push if an user account has expired.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 1bbb7b607f..71859c9bc0 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -222,11 +222,10 @@ class Auth(object):
# Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired:
user_id = user.to_string()
- expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
- if (
- expiration_ts is not None
- and self.clock.time_msec() >= expiration_ts
- ):
+ expired = yield self.store.is_account_expired(
+ user_id, self.clock.time_msec()
+ )
+ if expired:
raise AuthError(
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
)
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 2456f12f46..7c5e47bc81 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -66,6 +66,8 @@ class PusherPool:
self._pusher_shard_config = hs.config.push.pusher_shard_config
self._instance_name = hs.get_instance_name()
+ self._account_validity = hs.config.account_validity
+
# map from user id to app_id:pushkey to pusher
self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
@@ -196,6 +198,14 @@ class PusherPool:
for u in users_affected:
if u in self.pushers:
+ # Don't push if the user account has expired
+ if self._account_validity.enabled:
+ expired = yield self.store.is_account_expired(
+ u, self.clock.time_msec()
+ )
+ if expired:
+ continue
+
for p in self.pushers[u].values():
p.on_new_notifications(min_stream_id, max_stream_id)
@@ -217,6 +227,14 @@ class PusherPool:
for u in users_affected:
if u in self.pushers:
+ # Don't push if the user account has expired
+ if self._account_validity.enabled:
+ expired = yield self.store.is_account_expired(
+ u, self.clock.time_msec()
+ )
+ if expired:
+ continue
+
for p in self.pushers[u].values():
p.on_new_receipts(min_stream_id, max_stream_id)
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index e1b6cded65..7930041998 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -126,6 +126,23 @@ class RegistrationWorkerStore(SQLBaseStore):
return res
@defer.inlineCallbacks
+ def is_account_expired(self, user_id: str, current_ts: int):
+ """
+ Returns whether an user account is expired.
+
+ Args:
+ user_id: The user's ID
+ current_ts: The current timestamp
+
+ Returns:
+ Deferred[bool]: whether the user account has expired
+ """
+ expiration_ts = yield self.get_expiration_ts_for_user(user_id)
+ if expiration_ts is not None and current_ts >= expiration_ts:
+ return True
+ return False
+
+ @defer.inlineCallbacks
def set_account_validity_for_user(
self, user_id, expiration_ts, email_sent, renewal_token=None
):
|