diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index cf64cd63a4..91286c9b65 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -14,11 +14,25 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Tuple,
+ cast,
+)
from synapse.push import PusherConfig, ThrottleParams
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
@@ -117,7 +131,7 @@ class PusherWorkerStore(SQLBaseStore):
return self._decode_pushers_rows(ret)
async def get_all_pushers(self) -> Iterator[PusherConfig]:
- def get_pushers(txn):
+ def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]:
txn.execute("SELECT * FROM pushers")
rows = self.db_pool.cursor_to_dict(txn)
@@ -152,7 +166,9 @@ class PusherWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
- def get_all_updated_pushers_rows_txn(txn):
+ def get_all_updated_pushers_rows_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
sql = """
SELECT id, user_name, app_id, pushkey
FROM pushers
@@ -160,10 +176,13 @@ class PusherWorkerStore(SQLBaseStore):
ORDER BY id ASC LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
- updates = [
- (stream_id, (user_name, app_id, pushkey, False))
- for stream_id, user_name, app_id, pushkey in txn
- ]
+ updates = cast(
+ List[Tuple[int, tuple]],
+ [
+ (stream_id, (user_name, app_id, pushkey, False))
+ for stream_id, user_name, app_id, pushkey in txn
+ ],
+ )
sql = """
SELECT stream_id, user_id, app_id, pushkey
@@ -192,12 +211,12 @@ class PusherWorkerStore(SQLBaseStore):
)
@cached(num_args=1, max_entries=15000)
- async def get_if_user_has_pusher(self, user_id: str):
+ async def get_if_user_has_pusher(self, user_id: str) -> None:
# This only exists for the cachedList decorator
raise NotImplementedError()
async def update_pusher_last_stream_ordering(
- self, app_id, pushkey, user_id, last_stream_ordering
+ self, app_id: str, pushkey: str, user_id: str, last_stream_ordering: int
) -> None:
await self.db_pool.simple_update_one(
"pushers",
@@ -291,7 +310,7 @@ class PusherWorkerStore(SQLBaseStore):
last_user = progress.get("last_user", "")
- def _delete_pushers(txn) -> int:
+ def _delete_pushers(txn: LoggingTransaction) -> int:
sql = """
SELECT name FROM users
@@ -339,7 +358,7 @@ class PusherWorkerStore(SQLBaseStore):
last_pusher = progress.get("last_pusher", 0)
- def _delete_pushers(txn) -> int:
+ def _delete_pushers(txn: LoggingTransaction) -> int:
sql = """
SELECT p.id, access_token FROM pushers AS p
@@ -396,7 +415,7 @@ class PusherWorkerStore(SQLBaseStore):
last_pusher = progress.get("last_pusher", 0)
- def _delete_pushers(txn) -> int:
+ def _delete_pushers(txn: LoggingTransaction) -> int:
sql = """
SELECT p.id, p.user_name, p.app_id, p.pushkey
@@ -502,7 +521,7 @@ class PusherStore(PusherWorkerStore):
async def delete_pusher_by_app_id_pushkey_user_id(
self, app_id: str, pushkey: str, user_id: str
) -> None:
- def delete_pusher_txn(txn, stream_id):
+ def delete_pusher_txn(txn: LoggingTransaction, stream_id: int) -> None:
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_if_user_has_pusher, (user_id,)
)
@@ -547,7 +566,7 @@ class PusherStore(PusherWorkerStore):
# account.
pushers = list(await self.get_pushers_by_user_id(user_id))
- def delete_pushers_txn(txn, stream_ids):
+ def delete_pushers_txn(txn: LoggingTransaction, stream_ids: List[int]) -> None:
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_if_user_has_pusher, (user_id,)
)
|