diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index cc964604e2..64a2c31a5d 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -195,7 +195,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Returns info about the user account, if it exists."""
- def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
+ def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[UserInfo]:
# We could technically use simple_select_one here, but it would not perform
# the COALESCEs (unless hacked into the column names), which could yield
# confusing results.
@@ -213,35 +213,46 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
(user_id,),
)
- rows = self.db_pool.cursor_to_dict(txn)
-
- if len(rows) == 0:
+ row = txn.fetchone()
+ if not row:
return None
- return rows[0]
+ (
+ name,
+ is_guest,
+ admin,
+ consent_version,
+ consent_ts,
+ consent_server_notice_sent,
+ appservice_id,
+ creation_ts,
+ user_type,
+ deactivated,
+ shadow_banned,
+ approved,
+ locked,
+ ) = row
+
+ return UserInfo(
+ appservice_id=appservice_id,
+ consent_server_notice_sent=consent_server_notice_sent,
+ consent_version=consent_version,
+ consent_ts=consent_ts,
+ creation_ts=creation_ts,
+ is_admin=bool(admin),
+ is_deactivated=bool(deactivated),
+ is_guest=bool(is_guest),
+ is_shadow_banned=bool(shadow_banned),
+ user_id=UserID.from_string(name),
+ user_type=user_type,
+ approved=bool(approved),
+ locked=bool(locked),
+ )
- row = await self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
desc="get_user_by_id",
func=get_user_by_id_txn,
)
- if row is None:
- return None
-
- return UserInfo(
- appservice_id=row["appservice_id"],
- consent_server_notice_sent=row["consent_server_notice_sent"],
- consent_version=row["consent_version"],
- consent_ts=row["consent_ts"],
- creation_ts=row["creation_ts"],
- is_admin=bool(row["admin"]),
- is_deactivated=bool(row["deactivated"]),
- is_guest=bool(row["is_guest"]),
- is_shadow_banned=bool(row["shadow_banned"]),
- user_id=UserID.from_string(row["name"]),
- user_type=row["user_type"],
- approved=bool(row["approved"]),
- locked=bool(row["locked"]),
- )
async def is_trial_user(self, user_id: str) -> bool:
"""Checks if user is in the "trial" period, i.e. within the first
@@ -579,16 +590,31 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""
txn.execute(sql, (token,))
- rows = self.db_pool.cursor_to_dict(txn)
-
- if rows:
- row = rows[0]
-
- # This field is nullable, ensure it comes out as a boolean
- if row["token_used"] is None:
- row["token_used"] = False
+ row = txn.fetchone()
- return TokenLookupResult(**row)
+ if row:
+ (
+ user_id,
+ is_guest,
+ shadow_banned,
+ token_id,
+ device_id,
+ valid_until_ms,
+ token_owner,
+ token_used,
+ ) = row
+
+ return TokenLookupResult(
+ user_id=user_id,
+ is_guest=is_guest,
+ shadow_banned=shadow_banned,
+ token_id=token_id,
+ device_id=device_id,
+ valid_until_ms=valid_until_ms,
+ token_owner=token_owner,
+ # This field is nullable, ensure it comes out as a boolean
+ token_used=bool(token_used),
+ )
return None
@@ -833,11 +859,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""Counts all users registered on the homeserver."""
def _count_users(txn: LoggingTransaction) -> int:
- txn.execute("SELECT COUNT(*) AS users FROM users")
- rows = self.db_pool.cursor_to_dict(txn)
- if rows:
- return rows[0]["users"]
- return 0
+ txn.execute("SELECT COUNT(*) FROM users")
+ row = txn.fetchone()
+ assert row is not None
+ return row[0]
return await self.db_pool.runInteraction("count_users", _count_users)
@@ -891,11 +916,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""Counts all users without a special user_type registered on the homeserver."""
def _count_users(txn: LoggingTransaction) -> int:
- txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
- rows = self.db_pool.cursor_to_dict(txn)
- if rows:
- return rows[0]["users"]
- return 0
+ txn.execute("SELECT COUNT(*) FROM users where user_type is null")
+ row = txn.fetchone()
+ assert row is not None
+ return row[0]
return await self.db_pool.runInteraction("count_real_users", _count_users)
@@ -1252,12 +1276,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
txn.execute(sql, [])
- res = self.db_pool.cursor_to_dict(txn)
- if res:
- for user in res:
- self.set_expiration_date_for_user_txn(
- txn, user["name"], use_delta=True
- )
+ for (name,) in txn.fetchall():
+ self.set_expiration_date_for_user_txn(txn, name, use_delta=True)
await self.db_pool.runInteraction(
"get_users_with_no_expiration_date",
@@ -1963,11 +1983,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
(user_id,),
)
- rows = self.db_pool.cursor_to_dict(txn)
+ row = txn.fetchone()
+ assert row is not None
# We cast to bool because the value returned by the database engine might
# be an integer if we're using SQLite.
- return bool(rows[0]["approved"])
+ return bool(row[0])
return await self.db_pool.runInteraction(
desc="is_user_pending_approval",
@@ -2045,22 +2066,22 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
(last_user, batch_size),
)
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if not rows:
return True, 0
rows_processed_nb = 0
- for user in rows:
- if not user["count_tokens"] and not user["count_threepids"]:
- self.set_user_deactivated_status_txn(txn, user["name"], True)
+ for name, count_tokens, count_threepids in rows:
+ if not count_tokens and not count_threepids:
+ self.set_user_deactivated_status_txn(txn, name, True)
rows_processed_nb += 1
logger.info("Marked %d rows as deactivated", rows_processed_nb)
self.db_pool.updates._background_update_progress_txn(
- txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
+ txn, "users_set_deactivated_flag", {"user_id": rows[-1][0]}
)
if batch_size > len(rows):
|