diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 933d76e905..dec9858575 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -425,17 +425,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
account timestamp as milliseconds since the epoch. None if the account
has not been renewed using the current token yet.
"""
- ret_dict = await self.db_pool.simple_select_one(
- table="account_validity",
- keyvalues={"renewal_token": renewal_token},
- retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
- desc="get_user_from_renewal_token",
- )
-
- return (
- ret_dict["user_id"],
- ret_dict["expiration_ts_ms"],
- ret_dict["token_used_ts_ms"],
+ return cast(
+ Tuple[str, int, Optional[int]],
+ await self.db_pool.simple_select_one(
+ table="account_validity",
+ keyvalues={"renewal_token": renewal_token},
+ retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
+ desc="get_user_from_renewal_token",
+ ),
)
async def get_renewal_token_for_user(self, user_id: str) -> str:
@@ -989,16 +986,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Returns:
user id, or None if no user id/threepid mapping exists
"""
- ret = self.db_pool.simple_select_one_txn(
+ return self.db_pool.simple_select_one_onecol_txn(
txn,
"user_threepids",
{"medium": medium, "address": address},
- ["user_id"],
+ "user_id",
True,
)
- if ret:
- return ret["user_id"]
- return None
async def user_add_threepid(
self,
@@ -1435,16 +1429,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
if res is None:
return False
+ uses_allowed, pending, completed, expiry_time = res
+
# Check if the token has expired
now = self._clock.time_msec()
- if res["expiry_time"] and res["expiry_time"] < now:
+ if expiry_time and expiry_time < now:
return False
# Check if the token has been used up
- if (
- res["uses_allowed"]
- and res["pending"] + res["completed"] >= res["uses_allowed"]
- ):
+ if uses_allowed and pending + completed >= uses_allowed:
return False
# Otherwise, the token is valid
@@ -1490,8 +1483,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# Override type because the return type is only optional if
# allow_none is True, and we don't want mypy throwing errors
# about None not being indexable.
- res = cast(
- Dict[str, Any],
+ pending, completed = cast(
+ Tuple[int, int],
self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
@@ -1506,8 +1499,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"registration_tokens",
keyvalues={"token": token},
updatevalues={
- "completed": res["completed"] + 1,
- "pending": res["pending"] - 1,
+ "completed": completed + 1,
+ "pending": pending - 1,
},
)
@@ -1585,13 +1578,22 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Returns:
A dict, or None if token doesn't exist.
"""
- return await self.db_pool.simple_select_one(
+ row = await self.db_pool.simple_select_one(
"registration_tokens",
keyvalues={"token": token},
retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
allow_none=True,
desc="get_one_registration_token",
)
+ if row is None:
+ return None
+ return {
+ "token": row[0],
+ "uses_allowed": row[1],
+ "pending": row[2],
+ "completed": row[3],
+ "expiry_time": row[4],
+ }
async def generate_registration_token(
self, length: int, chars: str
@@ -1714,7 +1716,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return None
# Get all info about the token so it can be sent in the response
- return self.db_pool.simple_select_one_txn(
+ result = self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
@@ -1728,6 +1730,17 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
allow_none=True,
)
+ if result is None:
+ return result
+
+ return {
+ "token": result[0],
+ "uses_allowed": result[1],
+ "pending": result[2],
+ "completed": result[3],
+ "expiry_time": result[4],
+ }
+
return await self.db_pool.runInteraction(
"update_registration_token", _update_registration_token_txn
)
@@ -1939,11 +1952,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
keyvalues={"token": token},
updatevalues={"used_ts": ts},
)
- user_id = values["user_id"]
- expiry_ts = values["expiry_ts"]
- used_ts = values["used_ts"]
- auth_provider_id = values["auth_provider_id"]
- auth_provider_session_id = values["auth_provider_session_id"]
+ (
+ user_id,
+ expiry_ts,
+ used_ts,
+ auth_provider_id,
+ auth_provider_session_id,
+ ) = values
# Token was already used
if used_ts is not None:
@@ -2756,12 +2771,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
# reason, the next check is on the client secret, which is NOT NULL,
# so we don't have to worry about the client secret matching by
# accident.
- row = {"client_secret": None, "validated_at": None}
+ row = None, None
else:
raise ThreepidValidationError("Unknown session_id")
- retrieved_client_secret = row["client_secret"]
- validated_at = row["validated_at"]
+ retrieved_client_secret, validated_at = row
row = self.db_pool.simple_select_one_txn(
txn,
@@ -2775,8 +2789,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
raise ThreepidValidationError(
"Validation token not found or has expired"
)
- expires = row["expires"]
- next_link = row["next_link"]
+ expires, next_link = row
if retrieved_client_secret != client_secret:
raise ThreepidValidationError(
|