diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index f38bedbbcd..8ab7c42c4a 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -169,13 +169,17 @@ class UIAuthWorkerStore(SQLBaseStore):
that auth-type.
"""
results = {}
- for row in await self.db_pool.simple_select_list(
- table="ui_auth_sessions_credentials",
- keyvalues={"session_id": session_id},
- retcols=("stage_type", "result"),
- desc="get_completed_ui_auth_stages",
- ):
- results[row["stage_type"]] = db_to_json(row["result"])
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="ui_auth_sessions_credentials",
+ keyvalues={"session_id": session_id},
+ retcols=("stage_type", "result"),
+ desc="get_completed_ui_auth_stages",
+ ),
+ )
+ for stage_type, result in rows:
+ results[stage_type] = db_to_json(result)
return results
@@ -295,13 +299,15 @@ class UIAuthWorkerStore(SQLBaseStore):
Returns:
List of user_agent/ip pairs
"""
- rows = await self.db_pool.simple_select_list(
- table="ui_auth_sessions_ips",
- keyvalues={"session_id": session_id},
- retcols=("user_agent", "ip"),
- desc="get_user_agents_ips_to_ui_auth_session",
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="ui_auth_sessions_ips",
+ keyvalues={"session_id": session_id},
+ retcols=("user_agent", "ip"),
+ desc="get_user_agents_ips_to_ui_auth_session",
+ ),
)
- return [(row["user_agent"], row["ip"]) for row in rows]
async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
"""
@@ -337,13 +343,16 @@ class UIAuthWorkerStore(SQLBaseStore):
# If a registration token was used, decrement the pending counter
# before deleting the session.
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="ui_auth_sessions_credentials",
- column="session_id",
- iterable=session_ids,
- keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
- retcols=["result"],
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="ui_auth_sessions_credentials",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
+ retcols=["result"],
+ ),
)
# Get the tokens used and how much pending needs to be decremented by.
@@ -353,23 +362,25 @@ class UIAuthWorkerStore(SQLBaseStore):
# registration token stage for that session will be True.
# If a token was used to authenticate, but registration was
# never completed, the result will be the token used.
- token = db_to_json(r["result"])
+ token = db_to_json(r[0])
if isinstance(token, str):
token_counts[token] = token_counts.get(token, 0) + 1
# Update the `pending` counters.
if len(token_counts) > 0:
- token_rows = self.db_pool.simple_select_many_txn(
- txn,
- table="registration_tokens",
- column="token",
- iterable=list(token_counts.keys()),
- keyvalues={},
- retcols=["token", "pending"],
+ token_rows = cast(
+ List[Tuple[str, int]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="registration_tokens",
+ column="token",
+ iterable=list(token_counts.keys()),
+ keyvalues={},
+ retcols=["token", "pending"],
+ ),
)
- for token_row in token_rows:
- token = token_row["token"]
- new_pending = token_row["pending"] - token_counts[token]
+ for token, pending in token_rows:
+ new_pending = pending - token_counts[token]
self.db_pool.simple_update_one_txn(
txn,
table="registration_tokens",
|