diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index d856b76952..f2b7233d29 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@
# limitations under the License.
import logging
import re
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -30,17 +30,19 @@ from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID
from synapse.util.caches.descriptors import cached
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
logger = logging.getLogger(__name__)
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.config = hs.config
- self.clock = hs.get_clock()
# Note: we don't check this sequence for consistency as we'd have to
# call `find_max_generated_user_id_localpart` each time, which is
@@ -57,7 +59,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# Create a background job for culling expired 3PID validity tokens
if hs.config.run_background_tasks:
- self.clock.looping_call(
+ self._clock.looping_call(
self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
)
@@ -94,7 +96,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
if not info:
return False
- now = self.clock.time_msec()
+ now = self._clock.time_msec()
trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
return is_trial
@@ -259,7 +261,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return await self.db_pool.runInteraction(
"get_users_expiring_soon",
select_users_txn,
- self.clock.time_msec(),
+ self._clock.time_msec(),
self.config.account_validity.renew_at,
)
@@ -809,7 +811,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction(
"cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn,
- self.clock.time_msec(),
+ self._clock.time_msec(),
)
@wrap_as_background_process("account_validity_set_expiration_dates")
@@ -896,10 +898,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
- self.clock = hs.get_clock()
+ self._clock = hs.get_clock()
self.config = hs.config
self.db_pool.updates.register_background_index_update(
@@ -1039,7 +1041,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
deactivated,
)
- def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
+ def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
@@ -1065,7 +1067,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
@@ -1187,19 +1189,19 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def _register_user(
self,
txn,
- user_id,
- password_hash,
- was_guest,
- make_guest,
- appservice_id,
- create_profile_with_displayname,
- admin,
- user_type,
- shadow_banned,
+ user_id: str,
+ password_hash: Optional[str],
+ was_guest: bool,
+ make_guest: bool,
+ appservice_id: Optional[str],
+ create_profile_with_displayname: Optional[str],
+ admin: bool,
+ user_type: Optional[str],
+ shadow_banned: bool,
):
user_id_obj = UserID.from_string(user_id)
- now = int(self.clock.time())
+ now = int(self._clock.time())
try:
if was_guest:
@@ -1516,7 +1518,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
- updatevalues={"validated_at": self.clock.time_msec()},
+ updatevalues={"validated_at": self._clock.time_msec()},
)
return next_link
|