diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 87e28e22d3..a6a1671bd6 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -47,6 +47,27 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+# The type of a row in the pushers table.
+PusherRow = Tuple[
+ int, # id
+ str, # user_name
+ Optional[int], # access_token
+ str, # profile_tag
+ str, # kind
+ str, # app_id
+ str, # app_display_name
+ str, # device_display_name
+ str, # pushkey
+ int, # ts
+ str, # lang
+ str, # data
+ int, # last_stream_ordering
+ int, # last_success
+ int, # failing_since
+ bool, # enabled
+ str, # device_id
+]
+
class PusherWorkerStore(SQLBaseStore):
def __init__(
@@ -83,30 +104,66 @@ class PusherWorkerStore(SQLBaseStore):
self._remove_deleted_email_pushers,
)
- def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
+ def _decode_pushers_rows(
+ self,
+ rows: Iterable[PusherRow],
+ ) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table
Drops any rows whose data cannot be decoded
"""
- for r in rows:
- data_json = r["data"]
+ for (
+ id,
+ user_name,
+ access_token,
+ profile_tag,
+ kind,
+ app_id,
+ app_display_name,
+ device_display_name,
+ pushkey,
+ ts,
+ lang,
+ data,
+ last_stream_ordering,
+ last_success,
+ failing_since,
+ enabled,
+ device_id,
+ ) in rows:
try:
- r["data"] = db_to_json(data_json)
+ data_json = db_to_json(data)
except Exception as e:
logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
- r["id"],
- data_json,
+ id,
+ data,
e.args[0],
)
continue
- # If we're using SQLite, then boolean values are integers. This is
- # troublesome since some code using the return value of this method might
- # expect it to be a boolean, or will expose it to clients (in responses).
- r["enabled"] = bool(r["enabled"])
-
- yield PusherConfig(**r)
+ yield PusherConfig(
+ id=id,
+ user_name=user_name,
+ profile_tag=profile_tag,
+ kind=kind,
+ app_id=app_id,
+ app_display_name=app_display_name,
+ device_display_name=device_display_name,
+ pushkey=pushkey,
+ ts=ts,
+ lang=lang,
+ data=data_json,
+ last_stream_ordering=last_stream_ordering,
+ last_success=last_success,
+ failing_since=failing_since,
+ # If we're using SQLite, then boolean values are integers. This is
+ # troublesome since some code using the return value of this method might
+ # expect it to be a boolean, or will expose it to clients (in responses).
+ enabled=bool(enabled),
+ device_id=device_id,
+ access_token=access_token,
+ )
def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
@@ -136,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore):
The pushers for which the given columns have the given values.
"""
- def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def get_pushers_by_txn(txn: LoggingTransaction) -> List[PusherRow]:
# We could technically use simple_select_list here, but we need to call
# COALESCE on the 'enabled' column. While it is technically possible to give
# simple_select_list the whole `COALESCE(...) AS ...` as a column name, it
@@ -154,7 +211,7 @@ class PusherWorkerStore(SQLBaseStore):
txn.execute(sql, list(keyvalues.values()))
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[PusherRow], txn.fetchall())
ret = await self.db_pool.runInteraction(
desc="get_pushers_by",
@@ -164,14 +221,22 @@ class PusherWorkerStore(SQLBaseStore):
return self._decode_pushers_rows(ret)
async def get_enabled_pushers(self) -> Iterator[PusherConfig]:
- def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]:
- txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)")
- rows = self.db_pool.cursor_to_dict(txn)
-
- return self._decode_pushers_rows(rows)
+ def get_enabled_pushers_txn(txn: LoggingTransaction) -> List[PusherRow]:
+ txn.execute(
+ """
+ SELECT id, user_name, access_token, profile_tag, kind, app_id,
+ app_display_name, device_display_name, pushkey, ts, lang, data,
+ last_stream_ordering, last_success, failing_since,
+ enabled, device_id
+ FROM pushers WHERE COALESCE(enabled, TRUE)
+ """
+ )
+ return cast(List[PusherRow], txn.fetchall())
- return await self.db_pool.runInteraction(
- "get_enabled_pushers", get_enabled_pushers_txn
+ return self._decode_pushers_rows(
+ await self.db_pool.runInteraction(
+ "get_enabled_pushers", get_enabled_pushers_txn
+ )
)
async def get_all_updated_pushers_rows(
@@ -304,26 +369,28 @@ class PusherWorkerStore(SQLBaseStore):
)
async def get_throttle_params_by_room(
- self, pusher_id: str
+ self, pusher_id: int
) -> Dict[str, ThrottleParams]:
- res = await self.db_pool.simple_select_list(
- "pusher_throttle",
- {"pusher": pusher_id},
- ["room_id", "last_sent_ts", "throttle_ms"],
- desc="get_throttle_params_by_room",
+ res = cast(
+ List[Tuple[str, Optional[int], Optional[int]]],
+ await self.db_pool.simple_select_list(
+ "pusher_throttle",
+ {"pusher": pusher_id},
+ ["room_id", "last_sent_ts", "throttle_ms"],
+ desc="get_throttle_params_by_room",
+ ),
)
params_by_room = {}
- for row in res:
- params_by_room[row["room_id"]] = ThrottleParams(
- row["last_sent_ts"],
- row["throttle_ms"],
+ for room_id, last_sent_ts, throttle_ms in res:
+ params_by_room[room_id] = ThrottleParams(
+ last_sent_ts or 0, throttle_ms or 0
)
return params_by_room
async def set_throttle_params(
- self, pusher_id: str, room_id: str, params: ThrottleParams
+ self, pusher_id: int, room_id: str, params: ThrottleParams
) -> None:
await self.db_pool.simple_upsert(
"pusher_throttle",
@@ -534,7 +601,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
(last_pusher_id, batch_size),
)
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if len(rows) == 0:
return 0
@@ -550,19 +617,19 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
txn=txn,
table="pushers",
key_names=("id",),
- key_values=[(row["pusher_id"],) for row in rows],
+ key_values=[row[0] for row in rows],
value_names=("device_id", "access_token"),
# If there was already a device_id on the pusher, we only want to clear
# the access_token column, so we keep the existing device_id. Otherwise,
# we set the device_id we got from joining the access_tokens table.
value_values=[
- (row["pusher_device_id"] or row["token_device_id"], None)
- for row in rows
+ (pusher_device_id or token_device_id, None)
+ for _, pusher_device_id, token_device_id in rows
],
)
self.db_pool.updates._background_update_progress_txn(
- txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["pusher_id"]}
+ txn, "set_device_id_for_pushers", {"pusher_id": rows[-1][0]}
)
return len(rows)
|