diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 01f20c03c2..a83df7759d 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -36,11 +36,14 @@ logger = logging.getLogger(__name__)
class RegistrationWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RegistrationWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.config = hs.config
self.clock = hs.get_clock()
+ # Note: we don't check this sequence for consistency as we'd have to
+ # call `find_max_generated_user_id_localpart` each time, which is
+ # expensive if there are many entries.
self._user_id_seq = build_sequence_generator(
database.engine, find_max_generated_user_id_localpart, "user_id_seq",
)
@@ -116,6 +119,20 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_expiration_ts_for_user",
)
+ async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
+ """
+ Returns whether an user account is expired.
+
+ Args:
+ user_id: The user's ID
+ current_ts: The current timestamp
+
+ Returns:
+ Whether the user account has expired
+ """
+ expiration_ts = await self.get_expiration_ts_for_user(user_id)
+ return expiration_ts is not None and current_ts >= expiration_ts
+
async def set_account_validity_for_user(
self,
user_id: str,
@@ -379,7 +396,7 @@ class RegistrationWorkerStore(SQLBaseStore):
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
- ) -> str:
+ ) -> Optional[str]:
"""Look up a user by their external auth id
Args:
@@ -387,7 +404,7 @@ class RegistrationWorkerStore(SQLBaseStore):
external_id: id on that system
Returns:
- str|None: the mxid of the user, or None if they are not known
+ the mxid of the user, or None if they are not known
"""
return await self.db_pool.simple_select_one_onecol(
table="user_external_ids",
@@ -764,7 +781,7 @@ class RegistrationWorkerStore(SQLBaseStore):
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.clock = hs.get_clock()
self.config = hs.config
@@ -892,7 +909,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
class RegistrationStore(RegistrationBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RegistrationStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._account_validity = hs.config.account_validity
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
|