summary refs log tree commit diff
path: root/synapse/push/emailpusher.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/push/emailpusher.py')
-rw-r--r--synapse/push/emailpusher.py27
1 files changed, 14 insertions, 13 deletions
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 11a97b8df4..d2eff75a58 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -14,13 +14,13 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
 
 from twisted.internet.base import DelayedCall
 from twisted.internet.error import AlreadyCalled, AlreadyCancelled
 
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.push import Pusher
+from synapse.push import Pusher, PusherConfig, ThrottleParams
 from synapse.push.mailer import Mailer
 
 if TYPE_CHECKING:
@@ -60,15 +60,14 @@ class EmailPusher(Pusher):
     factor out the common parts
     """
 
-    def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any], mailer: Mailer):
-        super().__init__(hs, pusherdict)
+    def __init__(self, hs: "HomeServer", pusher_config: PusherConfig, mailer: Mailer):
+        super().__init__(hs, pusher_config)
         self.mailer = mailer
 
         self.store = self.hs.get_datastore()
-        self.email = pusherdict["pushkey"]
-        self.last_stream_ordering = pusherdict["last_stream_ordering"]
+        self.email = pusher_config.pushkey
         self.timed_call = None  # type: Optional[DelayedCall]
-        self.throttle_params = {}  # type: Dict[str, Dict[str, int]]
+        self.throttle_params = {}  # type: Dict[str, ThrottleParams]
         self._inited = False
 
         self._is_processing = False
@@ -132,6 +131,7 @@ class EmailPusher(Pusher):
 
             if not self._inited:
                 # this is our first loop: load up the throttle params
+                assert self.pusher_id is not None
                 self.throttle_params = await self.store.get_throttle_params_by_room(
                     self.pusher_id
                 )
@@ -157,6 +157,7 @@ class EmailPusher(Pusher):
         being run.
         """
         start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
+        assert start is not None
         unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
             self.user_id, start, self.max_stream_ordering
         )
@@ -244,13 +245,13 @@ class EmailPusher(Pusher):
 
     def get_room_throttle_ms(self, room_id: str) -> int:
         if room_id in self.throttle_params:
-            return self.throttle_params[room_id]["throttle_ms"]
+            return self.throttle_params[room_id].throttle_ms
         else:
             return 0
 
     def get_room_last_sent_ts(self, room_id: str) -> int:
         if room_id in self.throttle_params:
-            return self.throttle_params[room_id]["last_sent_ts"]
+            return self.throttle_params[room_id].last_sent_ts
         else:
             return 0
 
@@ -301,10 +302,10 @@ class EmailPusher(Pusher):
                 new_throttle_ms = min(
                     current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS
                 )
-        self.throttle_params[room_id] = {
-            "last_sent_ts": self.clock.time_msec(),
-            "throttle_ms": new_throttle_ms,
-        }
+        self.throttle_params[room_id] = ThrottleParams(
+            self.clock.time_msec(), new_throttle_ms,
+        )
+        assert self.pusher_id is not None
         await self.store.set_throttle_params(
             self.pusher_id, room_id, self.throttle_params[room_id]
         )