summary refs log tree commit diff
path: root/synapse/storage/databases/main/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/registration.py')
-rw-r--r--synapse/storage/databases/main/registration.py87
1 files changed, 50 insertions, 37 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py

index 933d76e905..dec9858575 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -425,17 +425,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): account timestamp as milliseconds since the epoch. None if the account has not been renewed using the current token yet. """ - ret_dict = await self.db_pool.simple_select_one( - table="account_validity", - keyvalues={"renewal_token": renewal_token}, - retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"], - desc="get_user_from_renewal_token", - ) - - return ( - ret_dict["user_id"], - ret_dict["expiration_ts_ms"], - ret_dict["token_used_ts_ms"], + return cast( + Tuple[str, int, Optional[int]], + await self.db_pool.simple_select_one( + table="account_validity", + keyvalues={"renewal_token": renewal_token}, + retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"], + desc="get_user_from_renewal_token", + ), ) async def get_renewal_token_for_user(self, user_id: str) -> str: @@ -989,16 +986,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Returns: user id, or None if no user id/threepid mapping exists """ - ret = self.db_pool.simple_select_one_txn( + return self.db_pool.simple_select_one_onecol_txn( txn, "user_threepids", {"medium": medium, "address": address}, - ["user_id"], + "user_id", True, ) - if ret: - return ret["user_id"] - return None async def user_add_threepid( self, @@ -1435,16 +1429,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): if res is None: return False + uses_allowed, pending, completed, expiry_time = res + # Check if the token has expired now = self._clock.time_msec() - if res["expiry_time"] and res["expiry_time"] < now: + if expiry_time and expiry_time < now: return False # Check if the token has been used up - if ( - res["uses_allowed"] - and res["pending"] + res["completed"] >= res["uses_allowed"] - ): + if uses_allowed and pending + completed >= uses_allowed: return False # Otherwise, the token is valid @@ -1490,8 +1483,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): # Override type because the return type is only optional if # allow_none is True, and we don't want mypy throwing errors # about None not being indexable. - res = cast( - Dict[str, Any], + pending, completed = cast( + Tuple[int, int], self.db_pool.simple_select_one_txn( txn, "registration_tokens", @@ -1506,8 +1499,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "registration_tokens", keyvalues={"token": token}, updatevalues={ - "completed": res["completed"] + 1, - "pending": res["pending"] - 1, + "completed": completed + 1, + "pending": pending - 1, }, ) @@ -1585,13 +1578,22 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Returns: A dict, or None if token doesn't exist. """ - return await self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( "registration_tokens", keyvalues={"token": token}, retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"], allow_none=True, desc="get_one_registration_token", ) + if row is None: + return None + return { + "token": row[0], + "uses_allowed": row[1], + "pending": row[2], + "completed": row[3], + "expiry_time": row[4], + } async def generate_registration_token( self, length: int, chars: str @@ -1714,7 +1716,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return None # Get all info about the token so it can be sent in the response - return self.db_pool.simple_select_one_txn( + result = self.db_pool.simple_select_one_txn( txn, "registration_tokens", keyvalues={"token": token}, @@ -1728,6 +1730,17 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): allow_none=True, ) + if result is None: + return result + + return { + "token": result[0], + "uses_allowed": result[1], + "pending": result[2], + "completed": result[3], + "expiry_time": result[4], + } + return await self.db_pool.runInteraction( "update_registration_token", _update_registration_token_txn ) @@ -1939,11 +1952,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): keyvalues={"token": token}, updatevalues={"used_ts": ts}, ) - user_id = values["user_id"] - expiry_ts = values["expiry_ts"] - used_ts = values["used_ts"] - auth_provider_id = values["auth_provider_id"] - auth_provider_session_id = values["auth_provider_session_id"] + ( + user_id, + expiry_ts, + used_ts, + auth_provider_id, + auth_provider_session_id, + ) = values # Token was already used if used_ts is not None: @@ -2756,12 +2771,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): # reason, the next check is on the client secret, which is NOT NULL, # so we don't have to worry about the client secret matching by # accident. - row = {"client_secret": None, "validated_at": None} + row = None, None else: raise ThreepidValidationError("Unknown session_id") - retrieved_client_secret = row["client_secret"] - validated_at = row["validated_at"] + retrieved_client_secret, validated_at = row row = self.db_pool.simple_select_one_txn( txn, @@ -2775,8 +2789,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): raise ThreepidValidationError( "Validation token not found or has expired" ) - expires = row["expires"] - next_link = row["next_link"] + expires, next_link = row if retrieved_client_secret != client_secret: raise ThreepidValidationError(