summary refs log tree commit diff
path: root/synapse/storage/databases/main/pusher.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/pusher.py')
-rw-r--r--synapse/storage/databases/main/pusher.py49
1 files changed, 34 insertions, 15 deletions
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index cf64cd63a4..91286c9b65 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -14,11 +14,25 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    cast,
+)
 
 from synapse.push import PusherConfig, ThrottleParams
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
@@ -117,7 +131,7 @@ class PusherWorkerStore(SQLBaseStore):
         return self._decode_pushers_rows(ret)
 
     async def get_all_pushers(self) -> Iterator[PusherConfig]:
-        def get_pushers(txn):
+        def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]:
             txn.execute("SELECT * FROM pushers")
             rows = self.db_pool.cursor_to_dict(txn)
 
@@ -152,7 +166,9 @@ class PusherWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_updated_pushers_rows_txn(txn):
+        def get_all_updated_pushers_rows_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
             sql = """
                 SELECT id, user_name, app_id, pushkey
                 FROM pushers
@@ -160,10 +176,13 @@ class PusherWorkerStore(SQLBaseStore):
                 ORDER BY id ASC LIMIT ?
             """
             txn.execute(sql, (last_id, current_id, limit))
-            updates = [
-                (stream_id, (user_name, app_id, pushkey, False))
-                for stream_id, user_name, app_id, pushkey in txn
-            ]
+            updates = cast(
+                List[Tuple[int, tuple]],
+                [
+                    (stream_id, (user_name, app_id, pushkey, False))
+                    for stream_id, user_name, app_id, pushkey in txn
+                ],
+            )
 
             sql = """
                 SELECT stream_id, user_id, app_id, pushkey
@@ -192,12 +211,12 @@ class PusherWorkerStore(SQLBaseStore):
         )
 
     @cached(num_args=1, max_entries=15000)
-    async def get_if_user_has_pusher(self, user_id: str):
+    async def get_if_user_has_pusher(self, user_id: str) -> None:
         # This only exists for the cachedList decorator
         raise NotImplementedError()
 
     async def update_pusher_last_stream_ordering(
-        self, app_id, pushkey, user_id, last_stream_ordering
+        self, app_id: str, pushkey: str, user_id: str, last_stream_ordering: int
     ) -> None:
         await self.db_pool.simple_update_one(
             "pushers",
@@ -291,7 +310,7 @@ class PusherWorkerStore(SQLBaseStore):
 
         last_user = progress.get("last_user", "")
 
-        def _delete_pushers(txn) -> int:
+        def _delete_pushers(txn: LoggingTransaction) -> int:
 
             sql = """
                 SELECT name FROM users
@@ -339,7 +358,7 @@ class PusherWorkerStore(SQLBaseStore):
 
         last_pusher = progress.get("last_pusher", 0)
 
-        def _delete_pushers(txn) -> int:
+        def _delete_pushers(txn: LoggingTransaction) -> int:
 
             sql = """
                 SELECT p.id, access_token FROM pushers AS p
@@ -396,7 +415,7 @@ class PusherWorkerStore(SQLBaseStore):
 
         last_pusher = progress.get("last_pusher", 0)
 
-        def _delete_pushers(txn) -> int:
+        def _delete_pushers(txn: LoggingTransaction) -> int:
 
             sql = """
                 SELECT p.id, p.user_name, p.app_id, p.pushkey
@@ -502,7 +521,7 @@ class PusherStore(PusherWorkerStore):
     async def delete_pusher_by_app_id_pushkey_user_id(
         self, app_id: str, pushkey: str, user_id: str
     ) -> None:
-        def delete_pusher_txn(txn, stream_id):
+        def delete_pusher_txn(txn: LoggingTransaction, stream_id: int) -> None:
             self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
                 txn, self.get_if_user_has_pusher, (user_id,)
             )
@@ -547,7 +566,7 @@ class PusherStore(PusherWorkerStore):
         # account.
         pushers = list(await self.get_pushers_by_user_id(user_id))
 
-        def delete_pushers_txn(txn, stream_ids):
+        def delete_pushers_txn(txn: LoggingTransaction, stream_ids: List[int]) -> None:
             self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
                 txn, self.get_if_user_has_pusher, (user_id,)
             )