summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-10-21 10:48:44 +0100
committerErik Johnston <erik@matrix.org>2020-10-21 10:50:15 +0100
commit8620b271134f82514cb2db23d4f77504a8e63e6e (patch)
tree93cbd2e234379be8ef14aa47f4b39cd7b3bb9844
parentAdd registration store to mypy (diff)
downloadsynapse-8620b271134f82514cb2db23d4f77504a8e63e6e.tar.xz
Add typing info to registration
-rw-r--r--synapse/storage/databases/main/registration.py46
1 files changed, 24 insertions, 22 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py

index d856b76952..f2b7233d29 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@ # limitations under the License. import logging import re -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -30,17 +30,19 @@ from synapse.storage.util.sequence import build_sequence_generator from synapse.types import UserID from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 logger = logging.getLogger(__name__) class RegistrationWorkerStore(CacheInvalidationWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.config = hs.config - self.clock = hs.get_clock() # Note: we don't check this sequence for consistency as we'd have to # call `find_max_generated_user_id_localpart` each time, which is @@ -57,7 +59,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): # Create a background job for culling expired 3PID validity tokens if hs.config.run_background_tasks: - self.clock.looping_call( + self._clock.looping_call( self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS ) @@ -94,7 +96,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): if not info: return False - now = self.clock.time_msec() + now = self._clock.time_msec() trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms return is_trial @@ -259,7 +261,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return await self.db_pool.runInteraction( "get_users_expiring_soon", select_users_txn, - self.clock.time_msec(), + self._clock.time_msec(), self.config.account_validity.renew_at, ) @@ -809,7 +811,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): await self.db_pool.runInteraction( "cull_expired_threepid_validation_tokens", cull_expired_threepid_validation_tokens_txn, - self.clock.time_msec(), + self._clock.time_msec(), ) @wrap_as_background_process("account_validity_set_expiration_dates") @@ -896,10 +898,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) - self.clock = hs.get_clock() + self._clock = hs.get_clock() self.config = hs.config self.db_pool.updates.register_background_index_update( @@ -1039,7 +1041,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): deactivated, ) - def set_user_deactivated_status_txn(self, txn, user_id, deactivated): + def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool): self.db_pool.simple_update_one_txn( txn=txn, table="users", @@ -1065,7 +1067,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors @@ -1187,19 +1189,19 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): def _register_user( self, txn, - user_id, - password_hash, - was_guest, - make_guest, - appservice_id, - create_profile_with_displayname, - admin, - user_type, - shadow_banned, + user_id: str, + password_hash: Optional[str], + was_guest: bool, + make_guest: bool, + appservice_id: Optional[str], + create_profile_with_displayname: Optional[str], + admin: bool, + user_type: Optional[str], + shadow_banned: bool, ): user_id_obj = UserID.from_string(user_id) - now = int(self.clock.time()) + now = int(self._clock.time()) try: if was_guest: @@ -1516,7 +1518,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, - updatevalues={"validated_at": self.clock.time_msec()}, + updatevalues={"validated_at": self._clock.time_msec()}, ) return next_link