summary refs log tree commit diff
path: root/synapse/storage/databases/main/pusher.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/pusher.py')
-rw-r--r--synapse/storage/databases/main/pusher.py121
1 files changed, 93 insertions, 28 deletions
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 87e28e22d3..c7eb7fc478 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -47,6 +47,27 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+# The type of a row in the pushers table.
+PusherRow = Tuple[
+    int,  # id
+    str,  # user_name
+    Optional[int],  # access_token
+    str,  # profile_tag
+    str,  # kind
+    str,  # app_id
+    str,  # app_display_name
+    str,  # device_display_name
+    str,  # pushkey
+    int,  # ts
+    str,  # lang
+    str,  # data
+    int,  # last_stream_ordering
+    int,  # last_success
+    int,  # failing_since
+    bool,  # enabled
+    str,  # device_id
+]
+
 
 class PusherWorkerStore(SQLBaseStore):
     def __init__(
@@ -83,30 +104,66 @@ class PusherWorkerStore(SQLBaseStore):
             self._remove_deleted_email_pushers,
         )
 
-    def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
+    def _decode_pushers_rows(
+        self,
+        rows: Iterable[PusherRow],
+    ) -> Iterator[PusherConfig]:
         """JSON-decode the data in the rows returned from the `pushers` table
 
         Drops any rows whose data cannot be decoded
         """
-        for r in rows:
-            data_json = r["data"]
+        for (
+            id,
+            user_name,
+            access_token,
+            profile_tag,
+            kind,
+            app_id,
+            app_display_name,
+            device_display_name,
+            pushkey,
+            ts,
+            lang,
+            data,
+            last_stream_ordering,
+            last_success,
+            failing_since,
+            enabled,
+            device_id,
+        ) in rows:
             try:
-                r["data"] = db_to_json(data_json)
+                data_json = db_to_json(data)
             except Exception as e:
                 logger.warning(
                     "Invalid JSON in data for pusher %d: %s, %s",
-                    r["id"],
-                    data_json,
+                    id,
+                    data,
                     e.args[0],
                 )
                 continue
 
-            # If we're using SQLite, then boolean values are integers. This is
-            # troublesome since some code using the return value of this method might
-            # expect it to be a boolean, or will expose it to clients (in responses).
-            r["enabled"] = bool(r["enabled"])
-
-            yield PusherConfig(**r)
+            yield PusherConfig(
+                id=id,
+                user_name=user_name,
+                profile_tag=profile_tag,
+                kind=kind,
+                app_id=app_id,
+                app_display_name=app_display_name,
+                device_display_name=device_display_name,
+                pushkey=pushkey,
+                ts=ts,
+                lang=lang,
+                data=data_json,
+                last_stream_ordering=last_stream_ordering,
+                last_success=last_success,
+                failing_since=failing_since,
+                # If we're using SQLite, then boolean values are integers. This is
+                # troublesome since some code using the return value of this method might
+                # expect it to be a boolean, or will expose it to clients (in responses).
+                enabled=bool(enabled),
+                device_id=device_id,
+                access_token=access_token,
+            )
 
     def get_pushers_stream_token(self) -> int:
         return self._pushers_id_gen.get_current_token()
@@ -136,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore):
             The pushers for which the given columns have the given values.
         """
 
-        def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+        def get_pushers_by_txn(txn: LoggingTransaction) -> List[PusherRow]:
             # We could technically use simple_select_list here, but we need to call
             # COALESCE on the 'enabled' column. While it is technically possible to give
             # simple_select_list the whole `COALESCE(...) AS ...` as a column name, it
@@ -154,7 +211,7 @@ class PusherWorkerStore(SQLBaseStore):
 
             txn.execute(sql, list(keyvalues.values()))
 
-            return self.db_pool.cursor_to_dict(txn)
+            return cast(List[PusherRow], txn.fetchall())
 
         ret = await self.db_pool.runInteraction(
             desc="get_pushers_by",
@@ -164,14 +221,22 @@ class PusherWorkerStore(SQLBaseStore):
         return self._decode_pushers_rows(ret)
 
     async def get_enabled_pushers(self) -> Iterator[PusherConfig]:
-        def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]:
-            txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)")
-            rows = self.db_pool.cursor_to_dict(txn)
-
-            return self._decode_pushers_rows(rows)
+        def get_enabled_pushers_txn(txn: LoggingTransaction) -> List[PusherRow]:
+            txn.execute(
+                """
+                SELECT id, user_name, access_token, profile_tag, kind, app_id,
+                    app_display_name, device_display_name, pushkey, ts, lang, data,
+                    last_stream_ordering, last_success, failing_since,
+                    enabled, device_id
+                FROM pushers WHERE COALESCE(enabled, TRUE)
+                """
+            )
+            return cast(List[PusherRow], txn.fetchall())
 
-        return await self.db_pool.runInteraction(
-            "get_enabled_pushers", get_enabled_pushers_txn
+        return self._decode_pushers_rows(
+            await self.db_pool.runInteraction(
+                "get_enabled_pushers", get_enabled_pushers_txn
+            )
         )
 
     async def get_all_updated_pushers_rows(
@@ -304,7 +369,7 @@ class PusherWorkerStore(SQLBaseStore):
         )
 
     async def get_throttle_params_by_room(
-        self, pusher_id: str
+        self, pusher_id: int
     ) -> Dict[str, ThrottleParams]:
         res = await self.db_pool.simple_select_list(
             "pusher_throttle",
@@ -323,7 +388,7 @@ class PusherWorkerStore(SQLBaseStore):
         return params_by_room
 
     async def set_throttle_params(
-        self, pusher_id: str, room_id: str, params: ThrottleParams
+        self, pusher_id: int, room_id: str, params: ThrottleParams
     ) -> None:
         await self.db_pool.simple_upsert(
             "pusher_throttle",
@@ -534,7 +599,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
                 (last_pusher_id, batch_size),
             )
 
-            rows = self.db_pool.cursor_to_dict(txn)
+            rows = txn.fetchall()
             if len(rows) == 0:
                 return 0
 
@@ -550,19 +615,19 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
                 txn=txn,
                 table="pushers",
                 key_names=("id",),
-                key_values=[(row["pusher_id"],) for row in rows],
+                key_values=[row[0] for row in rows],
                 value_names=("device_id", "access_token"),
                 # If there was already a device_id on the pusher, we only want to clear
                 # the access_token column, so we keep the existing device_id. Otherwise,
                 # we set the device_id we got from joining the access_tokens table.
                 value_values=[
-                    (row["pusher_device_id"] or row["token_device_id"], None)
-                    for row in rows
+                    (pusher_device_id or token_device_id, None)
+                    for _, pusher_device_id, token_device_id in rows
                 ],
             )
 
             self.db_pool.updates._background_update_progress_txn(
-                txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["pusher_id"]}
+                txn, "set_device_id_for_pushers", {"pusher_id": rows[-1][0]}
             )
 
             return len(rows)