summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorMathieu Velten <matmaul@gmail.com>2020-09-23 17:06:28 +0200
committerGitHub <noreply@github.com>2020-09-23 16:06:28 +0100
commit916bb9d0d15cf941e73b2e808c553a1edd1c2eb9 (patch)
tree9938653b43064cd6f66f6fb633f1f00cb24f1fd4 /synapse
parentFix missing null character check on guest_access room state (#8373) (diff)
downloadsynapse-916bb9d0d15cf941e73b2e808c553a1edd1c2eb9.tar.xz
Don't push if an user account has expired (#8353)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/auth.py6
-rw-r--r--synapse/push/pusherpool.py18
-rw-r--r--synapse/storage/databases/main/registration.py14
3 files changed, 33 insertions, 5 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 75388643ee..1071a0576e 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -218,11 +218,7 @@ class Auth:
             # Deny the request if the user account has expired.
             if self._account_validity.enabled and not allow_expired:
                 user_id = user.to_string()
-                expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
-                if (
-                    expiration_ts is not None
-                    and self.clock.time_msec() >= expiration_ts
-                ):
+                if await self.store.is_account_expired(user_id, self.clock.time_msec()):
                     raise AuthError(
                         403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
                     )
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index cc839ffce4..76150e117b 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -60,6 +60,8 @@ class PusherPool:
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
 
+        self._account_validity = hs.config.account_validity
+
         # We shard the handling of push notifications by user ID.
         self._pusher_shard_config = hs.config.push.pusher_shard_config
         self._instance_name = hs.get_instance_name()
@@ -202,6 +204,14 @@ class PusherPool:
             )
 
             for u in users_affected:
+                # Don't push if the user account has expired
+                if self._account_validity.enabled:
+                    expired = await self.store.is_account_expired(
+                        u, self.clock.time_msec()
+                    )
+                    if expired:
+                        continue
+
                 if u in self.pushers:
                     for p in self.pushers[u].values():
                         p.on_new_notifications(max_stream_id)
@@ -222,6 +232,14 @@ class PusherPool:
             )
 
             for u in users_affected:
+                # Don't push if the user account has expired
+                if self._account_validity.enabled:
+                    expired = await self.store.is_account_expired(
+                        u, self.clock.time_msec()
+                    )
+                    if expired:
+                        continue
+
                 if u in self.pushers:
                     for p in self.pushers[u].values():
                         p.on_new_receipts(min_stream_id, max_stream_id)
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 675e81fe34..33825e8949 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -116,6 +116,20 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="get_expiration_ts_for_user",
         )
 
+    async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
+        """
+        Returns whether an user account is expired.
+
+        Args:
+            user_id: The user's ID
+            current_ts: The current timestamp
+
+        Returns:
+            Whether the user account has expired
+        """
+        expiration_ts = await self.get_expiration_ts_for_user(user_id)
+        return expiration_ts is not None and current_ts >= expiration_ts
+
     async def set_account_validity_for_user(
         self,
         user_id: str,