diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 9e5eb2a445..4d405f2a0c 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -101,7 +101,7 @@ if TYPE_CHECKING:
class PusherConfig:
"""Parameters necessary to configure a pusher."""
- id: Optional[str]
+ id: Optional[int]
user_name: str
profile_tag: str
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 80f146dd53..16c284807a 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -151,10 +151,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
sql += " AND content != '{}'"
txn.execute(sql, (user_id,))
- rows = self.db_pool.cursor_to_dict(txn)
return {
- row["account_data_type"]: db_to_json(row["content"]) for row in rows
+ account_data_type: db_to_json(content)
+ for account_data_type, content in txn
}
return await self.db_pool.runInteraction(
@@ -196,13 +196,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
sql += " AND content != '{}'"
txn.execute(sql, (user_id,))
- rows = self.db_pool.cursor_to_dict(txn)
by_room: Dict[str, Dict[str, JsonDict]] = {}
- for row in rows:
- room_data = by_room.setdefault(row["room_id"], {})
+ for room_id, account_data_type, content in txn:
+ room_data = by_room.setdefault(room_id, {})
- room_data[row["account_data_type"]] = db_to_json(row["content"])
+ room_data[account_data_type] = db_to_json(content)
return by_room
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 0553a0621a..073a99cd84 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -14,17 +14,7 @@
# limitations under the License.
import logging
import re
-from typing import (
- TYPE_CHECKING,
- Any,
- Dict,
- List,
- Optional,
- Pattern,
- Sequence,
- Tuple,
- cast,
-)
+from typing import TYPE_CHECKING, List, Optional, Pattern, Sequence, Tuple, cast
from synapse.appservice import (
ApplicationService,
@@ -353,21 +343,15 @@ class ApplicationServiceTransactionWorkerStore(
def _get_oldest_unsent_txn(
txn: LoggingTransaction,
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[Tuple[int, str]]:
# Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent)
txn.execute(
- "SELECT * FROM application_services_txns WHERE as_id=?"
+ "SELECT txn_id, event_ids FROM application_services_txns WHERE as_id=?"
" ORDER BY txn_id ASC LIMIT 1",
(service.id,),
)
- rows = self.db_pool.cursor_to_dict(txn)
- if not rows:
- return None
-
- entry = rows[0]
-
- return entry
+ return cast(Optional[Tuple[int, str]], txn.fetchone())
entry = await self.db_pool.runInteraction(
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
@@ -376,8 +360,9 @@ class ApplicationServiceTransactionWorkerStore(
if not entry:
return None
- event_ids = db_to_json(entry["event_ids"])
+ txn_id, event_ids_str = entry
+ event_ids = db_to_json(event_ids_str)
events = await self.get_events_as_list(event_ids)
# TODO: to-device messages, one-time key counts, device list summaries and unused
@@ -385,7 +370,7 @@ class ApplicationServiceTransactionWorkerStore(
# We likely want to populate those for reliability.
return AppServiceTransaction(
service=service,
- id=entry["txn_id"],
+ id=txn_id,
events=events,
ephemeral=[],
to_device_messages=[],
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index df596f35f9..9f3804a504 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1413,13 +1413,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def get_devices_not_accessed_since_txn(
txn: LoggingTransaction,
- ) -> List[Dict[str, str]]:
+ ) -> List[Tuple[str, str]]:
sql = """
SELECT user_id, device_id
FROM devices WHERE last_seen < ? AND hidden = FALSE
"""
txn.execute(sql, (since_ms,))
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[Tuple[str, str]], txn.fetchall())
rows = await self.db_pool.runInteraction(
"get_devices_not_accessed_since",
@@ -1427,11 +1427,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)
devices: Dict[str, List[str]] = {}
- for row in rows:
+ for user_id, device_id in rows:
# Remote devices are never stale from our point of view.
- if self.hs.is_mine_id(row["user_id"]):
- user_devices = devices.setdefault(row["user_id"], [])
- user_devices.append(row["device_id"])
+ if self.hs.is_mine_id(user_id):
+ user_devices = devices.setdefault(user_id, [])
+ user_devices.append(device_id)
return devices
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 89fac23f93..749ae54e20 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -921,14 +921,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
}
txn.execute(sql, params)
- rows = self.db_pool.cursor_to_dict(txn)
- for row in rows:
- user_id = row["user_id"]
- key_type = row["keytype"]
- key = db_to_json(row["keydata"])
+ for user_id, key_type, key_data, _ in txn:
user_keys = result.setdefault(user_id, {})
- user_keys[key_type] = key
+ user_keys[key_type] = db_to_json(key_data)
return result
@@ -988,13 +984,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
query_params.extend(item)
txn.execute(sql, query_params)
- rows = self.db_pool.cursor_to_dict(txn)
# and add the signatures to the appropriate keys
- for row in rows:
- key_id: str = row["key_id"]
- target_user_id: str = row["target_user_id"]
- target_device_id: str = row["target_device_id"]
+ for target_user_id, target_device_id, key_id, signature in txn:
key_type = devices[(target_user_id, target_device_id)]
# We need to copy everything, because the result may have come
# from the cache. dict.copy only does a shallow copy, so we
@@ -1012,13 +1004,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
].copy()
if from_user_id in signatures:
user_sigs = signatures[from_user_id] = signatures[from_user_id]
- user_sigs[key_id] = row["signature"]
+ user_sigs[key_id] = signature
else:
- signatures[from_user_id] = {key_id: row["signature"]}
+ signatures[from_user_id] = {key_id: signature}
else:
- target_user_key["signatures"] = {
- from_user_id: {key_id: row["signature"]}
- }
+ target_user_key["signatures"] = {from_user_id: {key_id: signature}}
return keys
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 790d058c43..d4dcdb898c 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1654,8 +1654,6 @@ class PersistEventsStore:
) -> None:
to_prefill = []
- rows = []
-
ev_map = {e.event_id: e for e, _ in events_and_contexts}
if not ev_map:
return
@@ -1676,10 +1674,9 @@ class PersistEventsStore:
)
txn.execute(sql + clause, args)
- rows = self.db_pool.cursor_to_dict(txn)
- for row in rows:
- event = ev_map[row["event_id"]]
- if not row["rejects"] and not row["redacts"]:
+ for event_id, redacts, rejects in txn:
+ event = ev_map[event_id]
+ if not rejects and not redacts:
to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
async def external_prefill() -> None:
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 194b4e031f..805c23f89f 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -434,13 +434,21 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,))
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
txn.close()
- for row in rows:
- row["currently_active"] = bool(row["currently_active"])
-
- return [UserPresenceState(**row) for row in rows]
+ return [
+ UserPresenceState(
+ user_id=user_id,
+ state=state,
+ last_active_ts=last_active_ts,
+ last_federation_update_ts=last_federation_update_ts,
+ last_user_sync_ts=last_user_sync_ts,
+ status_msg=status_msg,
+ currently_active=bool(currently_active),
+ )
+ for user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active in rows
+ ]
def take_presence_startup_info(self) -> List[UserPresenceState]:
active_on_startup = self._presence_on_startup
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 87e28e22d3..c7eb7fc478 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -47,6 +47,27 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+# The type of a row in the pushers table.
+PusherRow = Tuple[
+ int, # id
+ str, # user_name
+ Optional[int], # access_token
+ str, # profile_tag
+ str, # kind
+ str, # app_id
+ str, # app_display_name
+ str, # device_display_name
+ str, # pushkey
+ int, # ts
+ str, # lang
+ str, # data
+ int, # last_stream_ordering
+ int, # last_success
+ int, # failing_since
+ bool, # enabled
+ str, # device_id
+]
+
class PusherWorkerStore(SQLBaseStore):
def __init__(
@@ -83,30 +104,66 @@ class PusherWorkerStore(SQLBaseStore):
self._remove_deleted_email_pushers,
)
- def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
+ def _decode_pushers_rows(
+ self,
+ rows: Iterable[PusherRow],
+ ) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table
Drops any rows whose data cannot be decoded
"""
- for r in rows:
- data_json = r["data"]
+ for (
+ id,
+ user_name,
+ access_token,
+ profile_tag,
+ kind,
+ app_id,
+ app_display_name,
+ device_display_name,
+ pushkey,
+ ts,
+ lang,
+ data,
+ last_stream_ordering,
+ last_success,
+ failing_since,
+ enabled,
+ device_id,
+ ) in rows:
try:
- r["data"] = db_to_json(data_json)
+ data_json = db_to_json(data)
except Exception as e:
logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
- r["id"],
- data_json,
+ id,
+ data,
e.args[0],
)
continue
- # If we're using SQLite, then boolean values are integers. This is
- # troublesome since some code using the return value of this method might
- # expect it to be a boolean, or will expose it to clients (in responses).
- r["enabled"] = bool(r["enabled"])
-
- yield PusherConfig(**r)
+ yield PusherConfig(
+ id=id,
+ user_name=user_name,
+ profile_tag=profile_tag,
+ kind=kind,
+ app_id=app_id,
+ app_display_name=app_display_name,
+ device_display_name=device_display_name,
+ pushkey=pushkey,
+ ts=ts,
+ lang=lang,
+ data=data_json,
+ last_stream_ordering=last_stream_ordering,
+ last_success=last_success,
+ failing_since=failing_since,
+ # If we're using SQLite, then boolean values are integers. This is
+ # troublesome since some code using the return value of this method might
+ # expect it to be a boolean, or will expose it to clients (in responses).
+ enabled=bool(enabled),
+ device_id=device_id,
+ access_token=access_token,
+ )
def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
@@ -136,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore):
The pushers for which the given columns have the given values.
"""
- def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def get_pushers_by_txn(txn: LoggingTransaction) -> List[PusherRow]:
# We could technically use simple_select_list here, but we need to call
# COALESCE on the 'enabled' column. While it is technically possible to give
# simple_select_list the whole `COALESCE(...) AS ...` as a column name, it
@@ -154,7 +211,7 @@ class PusherWorkerStore(SQLBaseStore):
txn.execute(sql, list(keyvalues.values()))
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[PusherRow], txn.fetchall())
ret = await self.db_pool.runInteraction(
desc="get_pushers_by",
@@ -164,14 +221,22 @@ class PusherWorkerStore(SQLBaseStore):
return self._decode_pushers_rows(ret)
async def get_enabled_pushers(self) -> Iterator[PusherConfig]:
- def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]:
- txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)")
- rows = self.db_pool.cursor_to_dict(txn)
-
- return self._decode_pushers_rows(rows)
+ def get_enabled_pushers_txn(txn: LoggingTransaction) -> List[PusherRow]:
+ txn.execute(
+ """
+ SELECT id, user_name, access_token, profile_tag, kind, app_id,
+ app_display_name, device_display_name, pushkey, ts, lang, data,
+ last_stream_ordering, last_success, failing_since,
+ enabled, device_id
+ FROM pushers WHERE COALESCE(enabled, TRUE)
+ """
+ )
+ return cast(List[PusherRow], txn.fetchall())
- return await self.db_pool.runInteraction(
- "get_enabled_pushers", get_enabled_pushers_txn
+ return self._decode_pushers_rows(
+ await self.db_pool.runInteraction(
+ "get_enabled_pushers", get_enabled_pushers_txn
+ )
)
async def get_all_updated_pushers_rows(
@@ -304,7 +369,7 @@ class PusherWorkerStore(SQLBaseStore):
)
async def get_throttle_params_by_room(
- self, pusher_id: str
+ self, pusher_id: int
) -> Dict[str, ThrottleParams]:
res = await self.db_pool.simple_select_list(
"pusher_throttle",
@@ -323,7 +388,7 @@ class PusherWorkerStore(SQLBaseStore):
return params_by_room
async def set_throttle_params(
- self, pusher_id: str, room_id: str, params: ThrottleParams
+ self, pusher_id: int, room_id: str, params: ThrottleParams
) -> None:
await self.db_pool.simple_upsert(
"pusher_throttle",
@@ -534,7 +599,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
(last_pusher_id, batch_size),
)
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if len(rows) == 0:
return 0
@@ -550,19 +615,19 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
txn=txn,
table="pushers",
key_names=("id",),
- key_values=[(row["pusher_id"],) for row in rows],
+ key_values=[row[0] for row in rows],
value_names=("device_id", "access_token"),
# If there was already a device_id on the pusher, we only want to clear
# the access_token column, so we keep the existing device_id. Otherwise,
# we set the device_id we got from joining the access_tokens table.
value_values=[
- (row["pusher_device_id"] or row["token_device_id"], None)
- for row in rows
+ (pusher_device_id or token_device_id, None)
+ for _, pusher_device_id, token_device_id in rows
],
)
self.db_pool.updates._background_update_progress_txn(
- txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["pusher_id"]}
+ txn, "set_device_id_for_pushers", {"pusher_id": rows[-1][0]}
)
return len(rows)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 3bab1024ea..b2645ab43c 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -313,25 +313,25 @@ class ReceiptsWorkerStore(SQLBaseStore):
) -> Sequence[JsonMapping]:
"""See get_linearized_receipts_for_room"""
- def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]:
if from_key:
sql = (
- "SELECT * FROM receipts_linearized WHERE"
+ "SELECT receipt_type, user_id, event_id, data"
+ " FROM receipts_linearized WHERE"
" room_id = ? AND stream_id > ? AND stream_id <= ?"
)
txn.execute(sql, (room_id, from_key, to_key))
else:
sql = (
- "SELECT * FROM receipts_linearized WHERE"
+ "SELECT receipt_type, user_id, event_id, data"
+ " FROM receipts_linearized WHERE"
" room_id = ? AND stream_id <= ?"
)
txn.execute(sql, (room_id, to_key))
- rows = self.db_pool.cursor_to_dict(txn)
-
- return rows
+ return cast(List[Tuple[str, str, str, str]], txn.fetchall())
rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
@@ -339,10 +339,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
return []
content: JsonDict = {}
- for row in rows:
- content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
- row["user_id"]
- ] = db_to_json(row["data"])
+ for receipt_type, user_id, event_id, data in rows:
+ content.setdefault(event_id, {}).setdefault(receipt_type, {})[
+ user_id
+ ] = db_to_json(data)
return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}]
@@ -357,10 +357,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
if not room_ids:
return {}
- def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def f(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, str, str, Optional[str], str]]:
if from_key:
sql = """
- SELECT * FROM receipts_linearized WHERE
+ SELECT room_id, receipt_type, user_id, event_id, thread_id, data
+ FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ? AND
"""
clause, args = make_in_list_sql_clause(
@@ -370,7 +373,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql + clause, [from_key, to_key] + list(args))
else:
sql = """
- SELECT * FROM receipts_linearized WHERE
+ SELECT room_id, receipt_type, user_id, event_id, thread_id, data
+ FROM receipts_linearized WHERE
stream_id <= ? AND
"""
@@ -380,29 +384,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql + clause, [to_key] + list(args))
- return self.db_pool.cursor_to_dict(txn)
+ return cast(
+ List[Tuple[str, str, str, str, Optional[str], str]], txn.fetchall()
+ )
txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
results: JsonDict = {}
- for row in txn_results:
+ for room_id, receipt_type, user_id, event_id, thread_id, data in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(
- row["room_id"],
- {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
+ room_id,
+ {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
)
# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
- event_entry = room_event["content"].setdefault(row["event_id"], {})
- receipt_type = event_entry.setdefault(row["receipt_type"], {})
+ event_entry = room_event["content"].setdefault(event_id, {})
+ receipt_type_dict = event_entry.setdefault(receipt_type, {})
- receipt_type[row["user_id"]] = db_to_json(row["data"])
- if row["thread_id"]:
- receipt_type[row["user_id"]]["thread_id"] = row["thread_id"]
+ receipt_type_dict[user_id] = db_to_json(data)
+ if thread_id:
+ receipt_type_dict[user_id]["thread_id"] = thread_id
results = {
room_id: [results[room_id]] if room_id in results else []
@@ -428,10 +434,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
A dictionary of roomids to a list of receipts.
"""
- def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
if from_key:
sql = """
- SELECT * FROM receipts_linearized WHERE
+ SELECT room_id, receipt_type, user_id, event_id, data
+ FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
@@ -439,7 +446,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, [from_key, to_key])
else:
sql = """
- SELECT * FROM receipts_linearized WHERE
+ SELECT room_id, receipt_type, user_id, event_id, data
+ FROM receipts_linearized WHERE
stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
@@ -447,27 +455,27 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, [to_key])
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[Tuple[str, str, str, str, str]], txn.fetchall())
txn_results = await self.db_pool.runInteraction(
"get_linearized_receipts_for_all_rooms", f
)
results: JsonDict = {}
- for row in txn_results:
+ for room_id, receipt_type, user_id, event_id, data in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(
- row["room_id"],
- {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
+ room_id,
+ {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
)
# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
- event_entry = room_event["content"].setdefault(row["event_id"], {})
- receipt_type = event_entry.setdefault(row["receipt_type"], {})
+ event_entry = room_event["content"].setdefault(event_id, {})
+ receipt_type_dict = event_entry.setdefault(receipt_type, {})
- receipt_type[row["user_id"]] = db_to_json(row["data"])
+ receipt_type_dict[user_id] = db_to_json(data)
return results
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index cc964604e2..64a2c31a5d 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -195,7 +195,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Returns info about the user account, if it exists."""
- def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
+ def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[UserInfo]:
# We could technically use simple_select_one here, but it would not perform
# the COALESCEs (unless hacked into the column names), which could yield
# confusing results.
@@ -213,35 +213,46 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
(user_id,),
)
- rows = self.db_pool.cursor_to_dict(txn)
-
- if len(rows) == 0:
+ row = txn.fetchone()
+ if not row:
return None
- return rows[0]
+ (
+ name,
+ is_guest,
+ admin,
+ consent_version,
+ consent_ts,
+ consent_server_notice_sent,
+ appservice_id,
+ creation_ts,
+ user_type,
+ deactivated,
+ shadow_banned,
+ approved,
+ locked,
+ ) = row
+
+ return UserInfo(
+ appservice_id=appservice_id,
+ consent_server_notice_sent=consent_server_notice_sent,
+ consent_version=consent_version,
+ consent_ts=consent_ts,
+ creation_ts=creation_ts,
+ is_admin=bool(admin),
+ is_deactivated=bool(deactivated),
+ is_guest=bool(is_guest),
+ is_shadow_banned=bool(shadow_banned),
+ user_id=UserID.from_string(name),
+ user_type=user_type,
+ approved=bool(approved),
+ locked=bool(locked),
+ )
- row = await self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
desc="get_user_by_id",
func=get_user_by_id_txn,
)
- if row is None:
- return None
-
- return UserInfo(
- appservice_id=row["appservice_id"],
- consent_server_notice_sent=row["consent_server_notice_sent"],
- consent_version=row["consent_version"],
- consent_ts=row["consent_ts"],
- creation_ts=row["creation_ts"],
- is_admin=bool(row["admin"]),
- is_deactivated=bool(row["deactivated"]),
- is_guest=bool(row["is_guest"]),
- is_shadow_banned=bool(row["shadow_banned"]),
- user_id=UserID.from_string(row["name"]),
- user_type=row["user_type"],
- approved=bool(row["approved"]),
- locked=bool(row["locked"]),
- )
async def is_trial_user(self, user_id: str) -> bool:
"""Checks if user is in the "trial" period, i.e. within the first
@@ -579,16 +590,31 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""
txn.execute(sql, (token,))
- rows = self.db_pool.cursor_to_dict(txn)
-
- if rows:
- row = rows[0]
-
- # This field is nullable, ensure it comes out as a boolean
- if row["token_used"] is None:
- row["token_used"] = False
+ row = txn.fetchone()
- return TokenLookupResult(**row)
+ if row:
+ (
+ user_id,
+ is_guest,
+ shadow_banned,
+ token_id,
+ device_id,
+ valid_until_ms,
+ token_owner,
+ token_used,
+ ) = row
+
+ return TokenLookupResult(
+ user_id=user_id,
+ is_guest=is_guest,
+ shadow_banned=shadow_banned,
+ token_id=token_id,
+ device_id=device_id,
+ valid_until_ms=valid_until_ms,
+ token_owner=token_owner,
+ # This field is nullable, ensure it comes out as a boolean
+ token_used=bool(token_used),
+ )
return None
@@ -833,11 +859,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""Counts all users registered on the homeserver."""
def _count_users(txn: LoggingTransaction) -> int:
- txn.execute("SELECT COUNT(*) AS users FROM users")
- rows = self.db_pool.cursor_to_dict(txn)
- if rows:
- return rows[0]["users"]
- return 0
+ txn.execute("SELECT COUNT(*) FROM users")
+ row = txn.fetchone()
+ assert row is not None
+ return row[0]
return await self.db_pool.runInteraction("count_users", _count_users)
@@ -891,11 +916,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""Counts all users without a special user_type registered on the homeserver."""
def _count_users(txn: LoggingTransaction) -> int:
- txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
- rows = self.db_pool.cursor_to_dict(txn)
- if rows:
- return rows[0]["users"]
- return 0
+ txn.execute("SELECT COUNT(*) FROM users where user_type is null")
+ row = txn.fetchone()
+ assert row is not None
+ return row[0]
return await self.db_pool.runInteraction("count_real_users", _count_users)
@@ -1252,12 +1276,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
txn.execute(sql, [])
- res = self.db_pool.cursor_to_dict(txn)
- if res:
- for user in res:
- self.set_expiration_date_for_user_txn(
- txn, user["name"], use_delta=True
- )
+ for (name,) in txn.fetchall():
+ self.set_expiration_date_for_user_txn(txn, name, use_delta=True)
await self.db_pool.runInteraction(
"get_users_with_no_expiration_date",
@@ -1963,11 +1983,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
(user_id,),
)
- rows = self.db_pool.cursor_to_dict(txn)
+ row = txn.fetchone()
+ assert row is not None
# We cast to bool because the value returned by the database engine might
# be an integer if we're using SQLite.
- return bool(rows[0]["approved"])
+ return bool(row[0])
return await self.db_pool.runInteraction(
desc="is_user_pending_approval",
@@ -2045,22 +2066,22 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
(last_user, batch_size),
)
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if not rows:
return True, 0
rows_processed_nb = 0
- for user in rows:
- if not user["count_tokens"] and not user["count_threepids"]:
- self.set_user_deactivated_status_txn(txn, user["name"], True)
+ for name, count_tokens, count_threepids in rows:
+ if not count_tokens and not count_threepids:
+ self.set_user_deactivated_status_txn(txn, name, True)
rows_processed_nb += 1
logger.info("Marked %d rows as deactivated", rows_processed_nb)
self.db_pool.updates._background_update_progress_txn(
- txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
+ txn, "users_set_deactivated_flag", {"user_id": rows[-1][0]}
)
if batch_size > len(rows):
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 719e11aea6..1d4d99932b 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -831,7 +831,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def get_retention_policy_for_room_txn(
txn: LoggingTransaction,
- ) -> List[Dict[str, Optional[int]]]:
+ ) -> Optional[Tuple[Optional[int], Optional[int]]]:
txn.execute(
"""
SELECT min_lifetime, max_lifetime FROM room_retention
@@ -841,7 +841,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
(room_id,),
)
- return self.db_pool.cursor_to_dict(txn)
+ return cast(Optional[Tuple[Optional[int], Optional[int]]], txn.fetchone())
ret = await self.db_pool.runInteraction(
"get_retention_policy_for_room",
@@ -856,8 +856,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
max_lifetime=self.config.retention.retention_default_max_lifetime,
)
- min_lifetime = ret[0]["min_lifetime"]
- max_lifetime = ret[0]["max_lifetime"]
+ min_lifetime, max_lifetime = ret
# If one of the room's policy's attributes isn't defined, use the matching
# attribute from the default policy.
@@ -1162,14 +1161,13 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
txn.execute(sql, args)
- rows = self.db_pool.cursor_to_dict(txn)
- rooms_dict = {}
-
- for row in rows:
- rooms_dict[row["room_id"]] = RetentionPolicy(
- min_lifetime=row["min_lifetime"],
- max_lifetime=row["max_lifetime"],
+ rooms_dict = {
+ room_id: RetentionPolicy(
+ min_lifetime=min_lifetime,
+ max_lifetime=max_lifetime,
)
+ for room_id, min_lifetime, max_lifetime in txn
+ }
if include_null:
# If required, do a second query that retrieves all of the rooms we know
@@ -1178,13 +1176,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
txn.execute(sql)
- rows = self.db_pool.cursor_to_dict(txn)
-
# If a room isn't already in the dict (i.e. it doesn't have a retention
# policy in its state), add it with a null policy.
- for row in rows:
- if row["room_id"] not in rooms_dict:
- rooms_dict[row["room_id"]] = RetentionPolicy()
+ for (room_id,) in txn:
+ if room_id not in rooms_dict:
+ rooms_dict[room_id] = RetentionPolicy()
return rooms_dict
@@ -1703,24 +1699,24 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
(last_room, batch_size),
)
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if not rows:
return True
- for row in rows:
- if not row["json"]:
+ for room_id, event_id, json in rows:
+ if not json:
retention_policy = {}
else:
- ev = db_to_json(row["json"])
+ ev = db_to_json(json)
retention_policy = ev["content"]
self.db_pool.simple_insert_txn(
txn=txn,
table="room_retention",
values={
- "room_id": row["room_id"],
- "event_id": row["event_id"],
+ "room_id": room_id,
+ "event_id": event_id,
"min_lifetime": retention_policy.get("min_lifetime"),
"max_lifetime": retention_policy.get("max_lifetime"),
},
@@ -1729,7 +1725,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
logger.info("Inserted %d rows into room_retention", len(rows))
self.db_pool.updates._background_update_progress_txn(
- txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
+ txn, "insert_room_retention", {"room_id": rows[-1][0]}
)
if batch_size > len(rows):
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index e93573f315..bbe08368db 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -1349,18 +1349,16 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if not rows:
return 0
- min_stream_id = rows[-1]["stream_ordering"]
+ min_stream_id = rows[-1][0]
to_update = []
- for row in rows:
- event_id = row["event_id"]
- room_id = row["room_id"]
+ for _, event_id, room_id, json in rows:
try:
- event_json = db_to_json(row["json"])
+ event_json = db_to_json(json)
content = event_json["content"]
except Exception:
continue
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index a7aae661d8..1d69c4a5f0 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -179,22 +179,24 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
# store_search_entries_txn with a generator function, but that
# would mean having two cursors open on the database at once.
# Instead we just build a list of results.
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if not rows:
return 0
- min_stream_id = rows[-1]["stream_ordering"]
+ min_stream_id = rows[-1][0]
event_search_rows = []
- for row in rows:
+ for (
+ stream_ordering,
+ event_id,
+ room_id,
+ etype,
+ json,
+ origin_server_ts,
+ ) in rows:
try:
- event_id = row["event_id"]
- room_id = row["room_id"]
- etype = row["type"]
- stream_ordering = row["stream_ordering"]
- origin_server_ts = row["origin_server_ts"]
try:
- event_json = db_to_json(row["json"])
+ event_json = db_to_json(json)
content = event_json["content"]
except Exception:
continue
diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py
index 5c5372a825..5555b53575 100644
--- a/synapse/storage/databases/main/task_scheduler.py
+++ b/synapse/storage/databases/main/task_scheduler.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, cast
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@@ -27,6 +27,8 @@ from synapse.util import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
+ScheduledTaskRow = Tuple[str, str, str, int, str, str, str, str]
+
class TaskSchedulerWorkerStore(SQLBaseStore):
def __init__(
@@ -38,13 +40,18 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
super().__init__(database, db_conn, hs)
@staticmethod
- def _convert_row_to_task(row: Dict[str, Any]) -> ScheduledTask:
- row["status"] = TaskStatus(row["status"])
- if row["params"] is not None:
- row["params"] = db_to_json(row["params"])
- if row["result"] is not None:
- row["result"] = db_to_json(row["result"])
- return ScheduledTask(**row)
+ def _convert_row_to_task(row: ScheduledTaskRow) -> ScheduledTask:
+ task_id, action, status, timestamp, resource_id, params, result, error = row
+ return ScheduledTask(
+ id=task_id,
+ action=action,
+ status=TaskStatus(status),
+ timestamp=timestamp,
+ resource_id=resource_id,
+ params=db_to_json(params) if params is not None else None,
+ result=db_to_json(result) if result is not None else None,
+ error=error,
+ )
async def get_scheduled_tasks(
self,
@@ -68,7 +75,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
Returns: a list of `ScheduledTask`, ordered by increasing timestamps
"""
- def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[ScheduledTaskRow]:
clauses: List[str] = []
args: List[Any] = []
if resource_id:
@@ -101,7 +108,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
args.append(limit)
txn.execute(sql, args)
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[ScheduledTaskRow], txn.fetchall())
rows = await self.db_pool.runInteraction(
"get_scheduled_tasks", get_scheduled_tasks_txn
@@ -193,7 +200,22 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
desc="get_scheduled_task",
)
- return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
+ return (
+ TaskSchedulerWorkerStore._convert_row_to_task(
+ (
+ row["id"],
+ row["action"],
+ row["status"],
+ row["timestamp"],
+ row["resource_id"],
+ row["params"],
+ row["result"],
+ row["error"],
+ )
+ )
+ if row
+ else None
+ )
async def delete_scheduled_task(self, id: str) -> None:
"""Delete a specific task from its id.
|