diff options
Diffstat (limited to 'synapse/storage/databases/main/registration.py')
-rw-r--r-- | synapse/storage/databases/main/registration.py | 27 |
1 files changed, 22 insertions, 5 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 01f20c03c2..a83df7759d 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -36,11 +36,14 @@ logger = logging.getLogger(__name__) class RegistrationWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RegistrationWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.config = hs.config self.clock = hs.get_clock() + # Note: we don't check this sequence for consistency as we'd have to + # call `find_max_generated_user_id_localpart` each time, which is + # expensive if there are many entries. self._user_id_seq = build_sequence_generator( database.engine, find_max_generated_user_id_localpart, "user_id_seq", ) @@ -116,6 +119,20 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_expiration_ts_for_user", ) + async def is_account_expired(self, user_id: str, current_ts: int) -> bool: + """ + Returns whether an user account is expired. + + Args: + user_id: The user's ID + current_ts: The current timestamp + + Returns: + Whether the user account has expired + """ + expiration_ts = await self.get_expiration_ts_for_user(user_id) + return expiration_ts is not None and current_ts >= expiration_ts + async def set_account_validity_for_user( self, user_id: str, @@ -379,7 +396,7 @@ class RegistrationWorkerStore(SQLBaseStore): async def get_user_by_external_id( self, auth_provider: str, external_id: str - ) -> str: + ) -> Optional[str]: """Look up a user by their external auth id Args: @@ -387,7 +404,7 @@ class RegistrationWorkerStore(SQLBaseStore): external_id: id on that system Returns: - str|None: the mxid of the user, or None if they are not known + the mxid of the user, or None if they are not known """ return await self.db_pool.simple_select_one_onecol( table="user_external_ids", @@ -764,7 +781,7 @@ class RegistrationWorkerStore(SQLBaseStore): class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.clock = hs.get_clock() self.config = hs.config @@ -892,7 +909,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): class RegistrationStore(RegistrationBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RegistrationStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._account_validity = hs.config.account_validity self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors |