summary refs log tree commit diff
path: root/synapse/storage/databases/main/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/registration.py')
-rw-r--r--synapse/storage/databases/main/registration.py27
1 files changed, 22 insertions, 5 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 01f20c03c2..a83df7759d 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -36,11 +36,14 @@ logger = logging.getLogger(__name__)
 
 class RegistrationWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RegistrationWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.config = hs.config
         self.clock = hs.get_clock()
 
+        # Note: we don't check this sequence for consistency as we'd have to
+        # call `find_max_generated_user_id_localpart` each time, which is
+        # expensive if there are many entries.
         self._user_id_seq = build_sequence_generator(
             database.engine, find_max_generated_user_id_localpart, "user_id_seq",
         )
@@ -116,6 +119,20 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="get_expiration_ts_for_user",
         )
 
+    async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
+        """
+        Returns whether an user account is expired.
+
+        Args:
+            user_id: The user's ID
+            current_ts: The current timestamp
+
+        Returns:
+            Whether the user account has expired
+        """
+        expiration_ts = await self.get_expiration_ts_for_user(user_id)
+        return expiration_ts is not None and current_ts >= expiration_ts
+
     async def set_account_validity_for_user(
         self,
         user_id: str,
@@ -379,7 +396,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
     async def get_user_by_external_id(
         self, auth_provider: str, external_id: str
-    ) -> str:
+    ) -> Optional[str]:
         """Look up a user by their external auth id
 
         Args:
@@ -387,7 +404,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             external_id: id on that system
 
         Returns:
-            str|None: the mxid of the user, or None if they are not known
+            the mxid of the user, or None if they are not known
         """
         return await self.db_pool.simple_select_one_onecol(
             table="user_external_ids",
@@ -764,7 +781,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.clock = hs.get_clock()
         self.config = hs.config
@@ -892,7 +909,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
 
 class RegistrationStore(RegistrationBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RegistrationStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self._account_validity = hs.config.account_validity
         self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors