summary refs log tree commit diff
path: root/synapse/storage/databases/main/pusher.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/pusher.py')
-rw-r--r--synapse/storage/databases/main/pusher.py93
1 files changed, 57 insertions, 36 deletions
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 7997242d90..77ba9d819e 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -15,18 +15,32 @@
 # limitations under the License.
 
 import logging
-from typing import Iterable, Iterator, List, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 
+from synapse.push import PusherConfig, ThrottleParams
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool
+from synapse.storage.types import Connection
+from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached, cachedList
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class PusherWorkerStore(SQLBaseStore):
-    def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]:
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+        super().__init__(database, db_conn, hs)
+        self._pushers_id_gen = StreamIdGenerator(
+            db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
+        )
+
+    def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
         """JSON-decode the data in the rows returned from the `pushers` table
 
         Drops any rows whose data cannot be decoded
@@ -44,21 +58,23 @@ class PusherWorkerStore(SQLBaseStore):
                 )
                 continue
 
-            yield r
+            yield PusherConfig(**r)
 
-    async def user_has_pusher(self, user_id):
+    async def user_has_pusher(self, user_id: str) -> bool:
         ret = await self.db_pool.simple_select_one_onecol(
             "pushers", {"user_name": user_id}, "id", allow_none=True
         )
         return ret is not None
 
-    def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
-        return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
+    async def get_pushers_by_app_id_and_pushkey(
+        self, app_id: str, pushkey: str
+    ) -> Iterator[PusherConfig]:
+        return await self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
 
-    def get_pushers_by_user_id(self, user_id):
-        return self.get_pushers_by({"user_name": user_id})
+    async def get_pushers_by_user_id(self, user_id: str) -> Iterator[PusherConfig]:
+        return await self.get_pushers_by({"user_name": user_id})
 
-    async def get_pushers_by(self, keyvalues):
+    async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]:
         ret = await self.db_pool.simple_select_list(
             "pushers",
             keyvalues,
@@ -83,7 +99,7 @@ class PusherWorkerStore(SQLBaseStore):
         )
         return self._decode_pushers_rows(ret)
 
-    async def get_all_pushers(self):
+    async def get_all_pushers(self) -> Iterator[PusherConfig]:
         def get_pushers(txn):
             txn.execute("SELECT * FROM pushers")
             rows = self.db_pool.cursor_to_dict(txn)
@@ -159,14 +175,16 @@ class PusherWorkerStore(SQLBaseStore):
         )
 
     @cached(num_args=1, max_entries=15000)
-    async def get_if_user_has_pusher(self, user_id):
+    async def get_if_user_has_pusher(self, user_id: str):
         # This only exists for the cachedList decorator
         raise NotImplementedError()
 
     @cachedList(
         cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
     )
-    async def get_if_users_have_pushers(self, user_ids):
+    async def get_if_users_have_pushers(
+        self, user_ids: Iterable[str]
+    ) -> Dict[str, bool]:
         rows = await self.db_pool.simple_select_many_batch(
             table="pushers",
             column="user_name",
@@ -224,7 +242,7 @@ class PusherWorkerStore(SQLBaseStore):
         return bool(updated)
 
     async def update_pusher_failing_since(
-        self, app_id, pushkey, user_id, failing_since
+        self, app_id: str, pushkey: str, user_id: str, failing_since: Optional[int]
     ) -> None:
         await self.db_pool.simple_update(
             table="pushers",
@@ -233,7 +251,9 @@ class PusherWorkerStore(SQLBaseStore):
             desc="update_pusher_failing_since",
         )
 
-    async def get_throttle_params_by_room(self, pusher_id):
+    async def get_throttle_params_by_room(
+        self, pusher_id: str
+    ) -> Dict[str, ThrottleParams]:
         res = await self.db_pool.simple_select_list(
             "pusher_throttle",
             {"pusher": pusher_id},
@@ -243,43 +263,44 @@ class PusherWorkerStore(SQLBaseStore):
 
         params_by_room = {}
         for row in res:
-            params_by_room[row["room_id"]] = {
-                "last_sent_ts": row["last_sent_ts"],
-                "throttle_ms": row["throttle_ms"],
-            }
+            params_by_room[row["room_id"]] = ThrottleParams(
+                row["last_sent_ts"], row["throttle_ms"],
+            )
 
         return params_by_room
 
-    async def set_throttle_params(self, pusher_id, room_id, params) -> None:
+    async def set_throttle_params(
+        self, pusher_id: str, room_id: str, params: ThrottleParams
+    ) -> None:
         # no need to lock because `pusher_throttle` has a primary key on
         # (pusher, room_id) so simple_upsert will retry
         await self.db_pool.simple_upsert(
             "pusher_throttle",
             {"pusher": pusher_id, "room_id": room_id},
-            params,
+            {"last_sent_ts": params.last_sent_ts, "throttle_ms": params.throttle_ms},
             desc="set_throttle_params",
             lock=False,
         )
 
 
 class PusherStore(PusherWorkerStore):
-    def get_pushers_stream_token(self):
+    def get_pushers_stream_token(self) -> int:
         return self._pushers_id_gen.get_current_token()
 
     async def add_pusher(
         self,
-        user_id,
-        access_token,
-        kind,
-        app_id,
-        app_display_name,
-        device_display_name,
-        pushkey,
-        pushkey_ts,
-        lang,
-        data,
-        last_stream_ordering,
-        profile_tag="",
+        user_id: str,
+        access_token: Optional[int],
+        kind: str,
+        app_id: str,
+        app_display_name: str,
+        device_display_name: str,
+        pushkey: str,
+        pushkey_ts: int,
+        lang: Optional[str],
+        data: Optional[JsonDict],
+        last_stream_ordering: int,
+        profile_tag: str = "",
     ) -> None:
         async with self._pushers_id_gen.get_next() as stream_id:
             # no need to lock because `pushers` has a unique key on
@@ -311,16 +332,16 @@ class PusherStore(PusherWorkerStore):
                 # invalidate, since we the user might not have had a pusher before
                 await self.db_pool.runInteraction(
                     "add_pusher",
-                    self._invalidate_cache_and_stream,
+                    self._invalidate_cache_and_stream,  # type: ignore
                     self.get_if_user_has_pusher,
                     (user_id,),
                 )
 
     async def delete_pusher_by_app_id_pushkey_user_id(
-        self, app_id, pushkey, user_id
+        self, app_id: str, pushkey: str, user_id: str
     ) -> None:
         def delete_pusher_txn(txn, stream_id):
-            self._invalidate_cache_and_stream(
+            self._invalidate_cache_and_stream(  # type: ignore
                 txn, self.get_if_user_has_pusher, (user_id,)
             )