summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12266.misc1
-rw-r--r--synapse/config/server.py4
-rw-r--r--synapse/handlers/account.py11
-rw-r--r--tests/rest/client/test_account.py58
4 files changed, 73 insertions, 1 deletions
diff --git a/changelog.d/12266.misc b/changelog.d/12266.misc
new file mode 100644
index 0000000000..59e2718370
--- /dev/null
+++ b/changelog.d/12266.misc
@@ -0,0 +1 @@
+Optionally include account validity expiration information to experimental [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) account status responses.
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 49cd0a4f19..38de4b8000 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -676,6 +676,10 @@ class ServerConfig(Config):
         ):
             raise ConfigError("'custom_template_directory' must be a string")
 
+        self.use_account_validity_in_account_status: bool = (
+            config.get("use_account_validity_in_account_status") or False
+        )
+
     def has_tls_listener(self) -> bool:
         return any(listener.tls for listener in self.listeners)
 
diff --git a/synapse/handlers/account.py b/synapse/handlers/account.py
index d5badf635b..c05a14304c 100644
--- a/synapse/handlers/account.py
+++ b/synapse/handlers/account.py
@@ -26,6 +26,10 @@ class AccountHandler:
         self._main_store = hs.get_datastores().main
         self._is_mine = hs.is_mine
         self._federation_client = hs.get_federation_client()
+        self._use_account_validity_in_account_status = (
+            hs.config.server.use_account_validity_in_account_status
+        )
+        self._account_validity_handler = hs.get_account_validity_handler()
 
     async def get_account_statuses(
         self,
@@ -106,6 +110,13 @@ class AccountHandler:
                 "deactivated": userinfo.is_deactivated,
             }
 
+            if self._use_account_validity_in_account_status:
+                status[
+                    "org.matrix.expired"
+                ] = await self._account_validity_handler.is_user_expired(
+                    user_id.to_string()
+                )
+
         return status
 
     async def _get_remote_account_statuses(
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index def836054d..27946febff 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -31,7 +31,7 @@ from synapse.rest import admin
 from synapse.rest.client import account, login, register, room
 from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
 from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
 from synapse.util import Clock
 
 from tests import unittest
@@ -1222,6 +1222,62 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
             expected_failures=[users[2]],
         )
 
+    @unittest.override_config(
+        {
+            "use_account_validity_in_account_status": True,
+        }
+    )
+    def test_no_account_validity(self) -> None:
+        """Tests that if we decide to include account validity in the response but no
+        account validity 'is_user_expired' callback is provided, we default to marking all
+        users as not expired.
+        """
+        user = self.register_user("someuser", "password")
+
+        self._test_status(
+            users=[user],
+            expected_statuses={
+                user: {
+                    "exists": True,
+                    "deactivated": False,
+                    "org.matrix.expired": False,
+                },
+            },
+            expected_failures=[],
+        )
+
+    @unittest.override_config(
+        {
+            "use_account_validity_in_account_status": True,
+        }
+    )
+    def test_account_validity_expired(self) -> None:
+        """Test that if we decide to include account validity in the response and the user
+        is expired, we return the correct info.
+        """
+        user = self.register_user("someuser", "password")
+
+        async def is_expired(user_id: str) -> bool:
+            # We can't blindly say everyone is expired, otherwise the request to get the
+            # account status will fail.
+            return UserID.from_string(user_id).localpart == "someuser"
+
+        self.hs.get_account_validity_handler()._is_user_expired_callbacks.append(
+            is_expired
+        )
+
+        self._test_status(
+            users=[user],
+            expected_statuses={
+                user: {
+                    "exists": True,
+                    "deactivated": False,
+                    "org.matrix.expired": True,
+                },
+            },
+            expected_failures=[],
+        )
+
     def _test_status(
         self,
         users: Optional[List[str]],