summary refs log tree commit diff
path: root/synapse/push/pusherpool.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/push/pusherpool.py')
-rw-r--r--synapse/push/pusherpool.py168
1 files changed, 86 insertions, 82 deletions
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index f325964983..8158356d40 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, Union
+from typing import TYPE_CHECKING, Dict, Iterable, Optional
 
 from prometheus_client import Gauge
 
@@ -23,11 +23,9 @@ from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
 )
-from synapse.push import PusherConfigException
-from synapse.push.emailpusher import EmailPusher
-from synapse.push.httppusher import HttpPusher
+from synapse.push import Pusher, PusherConfig, PusherConfigException
 from synapse.push.pusher import PusherFactory
-from synapse.types import RoomStreamToken
+from synapse.types import JsonDict, RoomStreamToken
 from synapse.util.async_helpers import concurrently_execute
 
 if TYPE_CHECKING:
@@ -77,9 +75,9 @@ class PusherPool:
         self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering()
 
         # map from user id to app_id:pushkey to pusher
-        self.pushers = {}  # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
+        self.pushers = {}  # type: Dict[str, Dict[str, Pusher]]
 
-    def start(self):
+    def start(self) -> None:
         """Starts the pushers off in a background process.
         """
         if not self._should_start_pushers:
@@ -89,21 +87,21 @@ class PusherPool:
 
     async def add_pusher(
         self,
-        user_id,
-        access_token,
-        kind,
-        app_id,
-        app_display_name,
-        device_display_name,
-        pushkey,
-        lang,
-        data,
-        profile_tag="",
-    ):
+        user_id: str,
+        access_token: Optional[int],
+        kind: str,
+        app_id: str,
+        app_display_name: str,
+        device_display_name: str,
+        pushkey: str,
+        lang: Optional[str],
+        data: JsonDict,
+        profile_tag: str = "",
+    ) -> Optional[Pusher]:
         """Creates a new pusher and adds it to the pool
 
         Returns:
-            EmailPusher|HttpPusher
+            The newly created pusher.
         """
 
         time_now_msec = self.clock.time_msec()
@@ -113,27 +111,28 @@ class PusherPool:
         # recreated, added and started: this means we have only one
         # code path adding pushers.
         self.pusher_factory.create_pusher(
-            {
-                "id": None,
-                "user_name": user_id,
-                "kind": kind,
-                "app_id": app_id,
-                "app_display_name": app_display_name,
-                "device_display_name": device_display_name,
-                "pushkey": pushkey,
-                "ts": time_now_msec,
-                "lang": lang,
-                "data": data,
-                "last_stream_ordering": None,
-                "last_success": None,
-                "failing_since": None,
-            }
+            PusherConfig(
+                id=None,
+                user_name=user_id,
+                access_token=access_token,
+                profile_tag=profile_tag,
+                kind=kind,
+                app_id=app_id,
+                app_display_name=app_display_name,
+                device_display_name=device_display_name,
+                pushkey=pushkey,
+                ts=time_now_msec,
+                lang=lang,
+                data=data,
+                last_stream_ordering=None,
+                last_success=None,
+                failing_since=None,
+            )
         )
 
         # create the pusher setting last_stream_ordering to the current maximum
-        # stream ordering in event_push_actions, so it will process
-        # pushes from this point onwards.
-        last_stream_ordering = await self.store.get_latest_push_action_stream_ordering()
+        # stream ordering, so it will process pushes from this point onwards.
+        last_stream_ordering = self.store.get_room_max_stream_ordering()
 
         await self.store.add_pusher(
             user_id=user_id,
@@ -154,43 +153,44 @@ class PusherPool:
         return pusher
 
     async def remove_pushers_by_app_id_and_pushkey_not_user(
-        self, app_id, pushkey, not_user_id
-    ):
+        self, app_id: str, pushkey: str, not_user_id: str
+    ) -> None:
         to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
         for p in to_remove:
-            if p["user_name"] != not_user_id:
+            if p.user_name != not_user_id:
                 logger.info(
                     "Removing pusher for app id %s, pushkey %s, user %s",
                     app_id,
                     pushkey,
-                    p["user_name"],
+                    p.user_name,
                 )
-                await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
+                await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
 
-    async def remove_pushers_by_access_token(self, user_id, access_tokens):
+    async def remove_pushers_by_access_token(
+        self, user_id: str, access_tokens: Iterable[int]
+    ) -> None:
         """Remove the pushers for a given user corresponding to a set of
         access_tokens.
 
         Args:
-            user_id (str): user to remove pushers for
-            access_tokens (Iterable[int]): access token *ids* to remove pushers
-                for
+            user_id: user to remove pushers for
+            access_tokens: access token *ids* to remove pushers for
         """
         if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
             return
 
         tokens = set(access_tokens)
         for p in await self.store.get_pushers_by_user_id(user_id):
-            if p["access_token"] in tokens:
+            if p.access_token in tokens:
                 logger.info(
                     "Removing pusher for app id %s, pushkey %s, user %s",
-                    p["app_id"],
-                    p["pushkey"],
-                    p["user_name"],
+                    p.app_id,
+                    p.pushkey,
+                    p.user_name,
                 )
-                await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
+                await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
 
-    def on_new_notifications(self, max_token: RoomStreamToken):
+    def on_new_notifications(self, max_token: RoomStreamToken) -> None:
         if not self.pushers:
             # nothing to do here.
             return
@@ -209,7 +209,7 @@ class PusherPool:
         self._on_new_notifications(max_token)
 
     @wrap_as_background_process("on_new_notifications")
-    async def _on_new_notifications(self, max_token: RoomStreamToken):
+    async def _on_new_notifications(self, max_token: RoomStreamToken) -> None:
         # We just use the minimum stream ordering and ignore the vector clock
         # component. This is safe to do as long as we *always* ignore the vector
         # clock components.
@@ -239,7 +239,9 @@ class PusherPool:
         except Exception:
             logger.exception("Exception in pusher on_new_notifications")
 
-    async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
+    async def on_new_receipts(
+        self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str]
+    ) -> None:
         if not self.pushers:
             # nothing to do here.
             return
@@ -267,28 +269,30 @@ class PusherPool:
         except Exception:
             logger.exception("Exception in pusher on_new_receipts")
 
-    async def start_pusher_by_id(self, app_id, pushkey, user_id):
+    async def start_pusher_by_id(
+        self, app_id: str, pushkey: str, user_id: str
+    ) -> Optional[Pusher]:
         """Look up the details for the given pusher, and start it
 
         Returns:
-            EmailPusher|HttpPusher|None: The pusher started, if any
+            The pusher started, if any
         """
         if not self._should_start_pushers:
-            return
+            return None
 
         if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
-            return
+            return None
 
         resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
 
-        pusher_dict = None
+        pusher_config = None
         for r in resultlist:
-            if r["user_name"] == user_id:
-                pusher_dict = r
+            if r.user_name == user_id:
+                pusher_config = r
 
         pusher = None
-        if pusher_dict:
-            pusher = await self._start_pusher(pusher_dict)
+        if pusher_config:
+            pusher = await self._start_pusher(pusher_config)
 
         return pusher
 
@@ -303,44 +307,44 @@ class PusherPool:
 
         logger.info("Started pushers")
 
-    async def _start_pusher(self, pusherdict):
+    async def _start_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]:
         """Start the given pusher
 
         Args:
-            pusherdict (dict): dict with the values pulled from the db table
+            pusher_config: The pusher configuration with the values pulled from the db table
 
         Returns:
-            EmailPusher|HttpPusher
+            The newly created pusher or None.
         """
         if not self._pusher_shard_config.should_handle(
-            self._instance_name, pusherdict["user_name"]
+            self._instance_name, pusher_config.user_name
         ):
-            return
+            return None
 
         try:
-            p = self.pusher_factory.create_pusher(pusherdict)
+            p = self.pusher_factory.create_pusher(pusher_config)
         except PusherConfigException as e:
             logger.warning(
                 "Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s",
-                pusherdict["id"],
-                pusherdict.get("user_name"),
-                pusherdict.get("app_id"),
-                pusherdict.get("pushkey"),
+                pusher_config.id,
+                pusher_config.user_name,
+                pusher_config.app_id,
+                pusher_config.pushkey,
                 e,
             )
-            return
+            return None
         except Exception:
             logger.exception(
-                "Couldn't start pusher id %i: caught Exception", pusherdict["id"],
+                "Couldn't start pusher id %i: caught Exception", pusher_config.id,
             )
-            return
+            return None
 
         if not p:
-            return
+            return None
 
-        appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
+        appid_pushkey = "%s:%s" % (pusher_config.app_id, pusher_config.pushkey)
 
-        byuser = self.pushers.setdefault(pusherdict["user_name"], {})
+        byuser = self.pushers.setdefault(pusher_config.user_name, {})
         if appid_pushkey in byuser:
             byuser[appid_pushkey].on_stop()
         byuser[appid_pushkey] = p
@@ -350,8 +354,8 @@ class PusherPool:
         # Check if there *may* be push to process. We do this as this check is a
         # lot cheaper to do than actually fetching the exact rows we need to
         # push.
-        user_id = pusherdict["user_name"]
-        last_stream_ordering = pusherdict["last_stream_ordering"]
+        user_id = pusher_config.user_name
+        last_stream_ordering = pusher_config.last_stream_ordering
         if last_stream_ordering:
             have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
                 user_id, last_stream_ordering
@@ -365,7 +369,7 @@ class PusherPool:
 
         return p
 
-    async def remove_pusher(self, app_id, pushkey, user_id):
+    async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
         appid_pushkey = "%s:%s" % (app_id, pushkey)
 
         byuser = self.pushers.get(user_id, {})