diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index def836054d..27946febff 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -31,7 +31,7 @@ from synapse.rest import admin
from synapse.rest.client import account, login, register, room
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests import unittest
@@ -1222,6 +1222,62 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_failures=[users[2]],
)
+ @unittest.override_config(
+ {
+ "use_account_validity_in_account_status": True,
+ }
+ )
+ def test_no_account_validity(self) -> None:
+ """Tests that if we decide to include account validity in the response but no
+ account validity 'is_user_expired' callback is provided, we default to marking all
+ users as not expired.
+ """
+ user = self.register_user("someuser", "password")
+
+ self._test_status(
+ users=[user],
+ expected_statuses={
+ user: {
+ "exists": True,
+ "deactivated": False,
+ "org.matrix.expired": False,
+ },
+ },
+ expected_failures=[],
+ )
+
+ @unittest.override_config(
+ {
+ "use_account_validity_in_account_status": True,
+ }
+ )
+ def test_account_validity_expired(self) -> None:
+ """Test that if we decide to include account validity in the response and the user
+ is expired, we return the correct info.
+ """
+ user = self.register_user("someuser", "password")
+
+ async def is_expired(user_id: str) -> bool:
+ # We can't blindly say everyone is expired, otherwise the request to get the
+ # account status will fail.
+ return UserID.from_string(user_id).localpart == "someuser"
+
+ self.hs.get_account_validity_handler()._is_user_expired_callbacks.append(
+ is_expired
+ )
+
+ self._test_status(
+ users=[user],
+ expected_statuses={
+ user: {
+ "exists": True,
+ "deactivated": False,
+ "org.matrix.expired": True,
+ },
+ },
+ expected_failures=[],
+ )
+
def _test_status(
self,
users: Optional[List[str]],
|