author     Patrick Cloke <clokep@users.noreply.github.com>  2021-12-21 08:25:34 -0500
committer  GitHub <noreply@github.com>  2021-12-21 13:25:34 +0000
commit     b6102230a7391d1acaa50cc6c389813f7e0fab84 (patch)
tree       605757fe7627d00ea873fb7a69f63128fbe53432
parent     Various opentracing enhancements (#11619) (diff)
download   synapse-b6102230a7391d1acaa50cc6c389813f7e0fab84.tar.xz
Add type hints to event_push_actions. (#11594)
-rw-r--r--  changelog.d/11594.misc                                    1
-rw-r--r--  mypy.ini                                                  4
-rw-r--r--  synapse/handlers/sync.py                                 12
-rw-r--r--  synapse/push/emailpusher.py                              18
-rw-r--r--  synapse/push/httppusher.py                               12
-rw-r--r--  synapse/push/mailer.py                                   40
-rw-r--r--  synapse/push/push_tools.py                                4
-rw-r--r--  synapse/rest/client/notifications.py                     20
-rw-r--r--  synapse/storage/databases/main/event_push_actions.py    249
-rw-r--r--  tests/replication/slave/storage/test_events.py            7
-rw-r--r--  tests/storage/test_event_push_actions.py                 12
11 files changed, 225 insertions, 154 deletions
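
Overview: this commit replaces the TypedDict-based push action types (BasePushAction, HttpPushAction, EmailPushAction) with frozen attrs classes, adds a UserPushAction class for /notifications responses and a NotifCounts class for the per-room notification/unread/highlight counts, and annotates the remaining functions in synapse/storage/databases/main/event_push_actions.py so the module can be taken off the mypy exclusion list with disallow_untyped_defs enabled. Callers in the sync handler, the email and HTTP pushers, the mailer, push_tools and the notifications servlet switch from dict indexing to attribute access. For callers the pattern changes roughly as follows (illustrative sketch based on the diff below, not an excerpt from it):

    # Before: the store returned a plain dict / TypedDict.
    counts = await store.get_unread_event_push_actions_by_room_for_user(
        room_id, user_id, last_read_event_id
    )
    badge = counts["notify_count"]

    # After: the store returns a frozen attrs instance (NotifCounts),
    # so fields are read as attributes and are visible to mypy.
    counts = await store.get_unread_event_push_actions_by_room_for_user(
        room_id, user_id, last_read_event_id
    )
    badge = counts.notify_count
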
diff --git a/changelog.d/11594.misc b/changelog.d/11594.misc
new file mode 100644
index 0000000000..d451940bf2
--- /dev/null
+++ b/changelog.d/11594.misc
@@ -0,0 +1 @@
+Add missing type hints to storage classes.
diff --git a/mypy.ini b/mypy.ini
index 3279c9bb21..57e1a5df43 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -28,7 +28,6 @@ exclude = (?x)
    |synapse/storage/databases/main/cache.py
    |synapse/storage/databases/main/devices.py
    |synapse/storage/databases/main/event_federation.py
-   |synapse/storage/databases/main/event_push_actions.py
    |synapse/storage/databases/main/events_bg_updates.py
    |synapse/storage/databases/main/group_server.py
    |synapse/storage/databases/main/metrics.py
@@ -200,6 +199,9 @@ disallow_untyped_defs = True
 [mypy-synapse.storage.databases.main.end_to_end_keys]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.event_push_actions]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.databases.main.events_worker]
 disallow_untyped_defs = True
 
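Note on the mypy.ini hunk above: removing the file from the global exclude block means mypy now analyses it when checking the synapse package, and the new per-module section turns on disallow_untyped_defs, so any function in the module left without annotations becomes a type-checking error. A quick local check (sketch; assumes mypy is installed and picks up the repository's mypy.ini from the working directory):

    mypy synapse/storage/databases/main/event_push_actions.py
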
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index d24124d6ac..7baf3f199c 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -36,6 +36,7 @@ from synapse.events import EventBase
 from synapse.logging.context import current_context
 from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
 from synapse.push.clientformat import format_push_rules_for_user
+from synapse.storage.databases.main.event_push_actions import NotifCounts
 from synapse.storage.roommember import MemberSummary
 from synapse.storage.state import StateFilter
 from synapse.types import (
@@ -1041,7 +1042,7 @@ class SyncHandler:
 
     async def unread_notifs_for_room_id(
         self, room_id: str, sync_config: SyncConfig
-    ) -> Dict[str, int]:
+    ) -> NotifCounts:
         with Measure(self.clock, "unread_notifs_for_room_id"):
             last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
                 user_id=sync_config.user.to_string(),
@@ -1049,10 +1050,9 @@ class SyncHandler:
                 receipt_type=ReceiptTypes.READ,
             )
 
-            notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
+            return await self.store.get_unread_event_push_actions_by_room_for_user(
                 room_id, sync_config.user.to_string(), last_unread_event_id
             )
-            return notifs
 
     async def generate_sync_result(
         self,
@@ -2174,10 +2174,10 @@ class SyncHandler:
                 if room_sync or always_include:
                     notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
 
-                    unread_notifications["notification_count"] = notifs["notify_count"]
-                    unread_notifications["highlight_count"] = notifs["highlight_count"]
+                    unread_notifications["notification_count"] = notifs.notify_count
+                    unread_notifications["highlight_count"] = notifs.highlight_count
 
-                    room_sync.unread_count = notifs["unread_count"]
+                    room_sync.unread_count = notifs.unread_count
 
                     sync_result_builder.joined.append(room_sync)
 
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 4f13c0418a..39bb2acae4 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -177,12 +177,12 @@ class EmailPusher(Pusher):
             return
 
         for push_action in unprocessed:
-            received_at = push_action["received_ts"]
+            received_at = push_action.received_ts
             if received_at is None:
                 received_at = 0
             notif_ready_at = received_at + DELAY_BEFORE_MAIL_MS
 
-            room_ready_at = self.room_ready_to_notify_at(push_action["room_id"])
+            room_ready_at = self.room_ready_to_notify_at(push_action.room_id)
 
             should_notify_at = max(notif_ready_at, room_ready_at)
 
@@ -193,23 +193,23 @@ class EmailPusher(Pusher):
                 # to be delivered.
 
                 reason: EmailReason = {
-                    "room_id": push_action["room_id"],
+                    "room_id": push_action.room_id,
                     "now": self.clock.time_msec(),
                     "received_at": received_at,
                     "delay_before_mail_ms": DELAY_BEFORE_MAIL_MS,
-                    "last_sent_ts": self.get_room_last_sent_ts(push_action["room_id"]),
-                    "throttle_ms": self.get_room_throttle_ms(push_action["room_id"]),
+                    "last_sent_ts": self.get_room_last_sent_ts(push_action.room_id),
+                    "throttle_ms": self.get_room_throttle_ms(push_action.room_id),
                 }
 
                 await self.send_notification(unprocessed, reason)
 
                 await self.save_last_stream_ordering_and_success(
-                    max(ea["stream_ordering"] for ea in unprocessed)
+                    max(ea.stream_ordering for ea in unprocessed)
                 )
 
                 # we update the throttle on all the possible unprocessed push actions
                 for ea in unprocessed:
-                    await self.sent_notif_update_throttle(ea["room_id"], ea)
+                    await self.sent_notif_update_throttle(ea.room_id, ea)
                 break
             else:
                 if soonest_due_at is None or should_notify_at < soonest_due_at:
@@ -284,10 +284,10 @@ class EmailPusher(Pusher):
         # THROTTLE_RESET_AFTER_MS after the previous one that triggered a
         # notif, we release the throttle. Otherwise, the throttle is increased.
         time_of_previous_notifs = await self.store.get_time_of_last_push_action_before(
-            notified_push_action["stream_ordering"]
+            notified_push_action.stream_ordering
         )
 
-        time_of_this_notifs = notified_push_action["received_ts"]
+        time_of_this_notifs = notified_push_action.received_ts
 
         if time_of_previous_notifs is not None and time_of_this_notifs is not None:
             gap = time_of_this_notifs - time_of_previous_notifs
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 3fa603ccb7..96559081d0 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -199,7 +199,7 @@ class HttpPusher(Pusher):
                 "http-push",
                 tags={
                     "authenticated_entity": self.user_id,
-                    "event_id": push_action["event_id"],
+                    "event_id": push_action.event_id,
                     "app_id": self.app_id,
                     "app_display_name": self.app_display_name,
                 },
@@ -209,7 +209,7 @@ class HttpPusher(Pusher):
             if processed:
                 http_push_processed_counter.inc()
                 self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
-                self.last_stream_ordering = push_action["stream_ordering"]
+                self.last_stream_ordering = push_action.stream_ordering
                 pusher_still_exists = (
                     await self.store.update_pusher_last_stream_ordering_and_success(
                         self.app_id,
@@ -252,7 +252,7 @@ class HttpPusher(Pusher):
                         self.pushkey,
                     )
                     self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
-                    self.last_stream_ordering = push_action["stream_ordering"]
+                    self.last_stream_ordering = push_action.stream_ordering
                     await self.store.update_pusher_last_stream_ordering(
                         self.app_id,
                         self.pushkey,
@@ -275,17 +275,17 @@ class HttpPusher(Pusher):
                     break
 
     async def _process_one(self, push_action: HttpPushAction) -> bool:
-        if "notify" not in push_action["actions"]:
+        if "notify" not in push_action.actions:
             return True
 
-        tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
+        tweaks = push_rule_evaluator.tweaks_for_actions(push_action.actions)
         badge = await push_tools.get_badge_count(
             self.hs.get_datastore(),
             self.user_id,
             group_by_room=self._group_unread_count_by_room,
         )
 
-        event = await self.store.get_event(push_action["event_id"], allow_none=True)
+        event = await self.store.get_event(push_action.event_id, allow_none=True)
         if event is None:
             return True  # It's been redacted
         rejected = await self.dispatch_push(event, tweaks, badge)
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index ba4f866487..ff904c2b4a 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -232,15 +232,13 @@ class Mailer:
             reason: The notification that was ready and is the cause of an email
                 being sent.
         """
-        rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions])
+        rooms_in_order = deduped_ordered_list([pa.room_id for pa in push_actions])
 
-        notif_events = await self.store.get_events(
-            [pa["event_id"] for pa in push_actions]
-        )
+        notif_events = await self.store.get_events([pa.event_id for pa in push_actions])
 
         notifs_by_room: Dict[str, List[EmailPushAction]] = {}
         for pa in push_actions:
-            notifs_by_room.setdefault(pa["room_id"], []).append(pa)
+            notifs_by_room.setdefault(pa.room_id, []).append(pa)
 
         # collect the current state for all the rooms in which we have
         # notifications
@@ -264,7 +262,7 @@ class Mailer:
         await concurrently_execute(_fetch_room_state, rooms_in_order, 3)
 
         # actually sort our so-called rooms_in_order list, most recent room first
-        rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0))
+        rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1].received_ts or 0))
 
         rooms: List[RoomVars] = []
 
@@ -356,7 +354,7 @@ class Mailer:
         # Check if one of the notifs is an invite event for the user.
         is_invite = False
         for n in notifs:
-            ev = notif_events[n["event_id"]]
+            ev = notif_events[n.event_id]
             if ev.type == EventTypes.Member and ev.state_key == user_id:
                 if ev.content.get("membership") == Membership.INVITE:
                     is_invite = True
@@ -376,7 +374,7 @@ class Mailer:
         if not is_invite:
             for n in notifs:
                 notifvars = await self._get_notif_vars(
-                    n, user_id, notif_events[n["event_id"]], room_state_ids
+                    n, user_id, notif_events[n.event_id], room_state_ids
                 )
 
                 # merge overlapping notifs together.
@@ -444,15 +442,15 @@ class Mailer:
         """
 
         results = await self.store.get_events_around(
-            notif["room_id"],
-            notif["event_id"],
+            notif.room_id,
+            notif.event_id,
             before_limit=CONTEXT_BEFORE,
             after_limit=CONTEXT_AFTER,
         )
 
         ret: NotifVars = {
             "link": self._make_notif_link(notif),
-            "ts": notif["received_ts"],
+            "ts": notif.received_ts,
             "messages": [],
         }
 
@@ -516,7 +514,7 @@ class Mailer:
 
         ret: MessageVars = {
             "event_type": event.type,
-            "is_historical": event.event_id != notif["event_id"],
+            "is_historical": event.event_id != notif.event_id,
             "id": event.event_id,
             "ts": event.origin_server_ts,
             "sender_name": sender_name,
@@ -610,7 +608,7 @@ class Mailer:
         # See if one of the notifs is an invite event for the user
         invite_event = None
         for n in notifs:
-            ev = notif_events[n["event_id"]]
+            ev = notif_events[n.event_id]
             if ev.type == EventTypes.Member and ev.state_key == user_id:
                 if ev.content.get("membership") == Membership.INVITE:
                     invite_event = ev
@@ -659,7 +657,7 @@ class Mailer:
         if len(notifs) == 1:
             # There is just the one notification, so give some detail
             sender_name = None
-            event = notif_events[notifs[0]["event_id"]]
+            event = notif_events[notifs[0].event_id]
             if ("m.room.member", event.sender) in room_state_ids:
                 state_event_id = room_state_ids[("m.room.member", event.sender)]
                 state_event = await self.store.get_event(state_event_id)
@@ -753,9 +751,9 @@ class Mailer:
         # are already in descending received_ts.
         sender_ids = {}
         for n in notifs:
-            sender = notif_events[n["event_id"]].sender
+            sender = notif_events[n.event_id].sender
             if sender not in sender_ids:
-                sender_ids[sender] = n["event_id"]
+                sender_ids[sender] = n.event_id
 
         # Get the actual member events (in order to calculate a pretty name for
         # the room).
@@ -830,17 +828,17 @@ class Mailer:
         if self.hs.config.email.email_riot_base_url:
             return "%s/#/room/%s/%s" % (
                 self.hs.config.email.email_riot_base_url,
-                notif["room_id"],
-                notif["event_id"],
+                notif.room_id,
+                notif.event_id,
             )
         elif self.app_name == "Vector":
             # need /beta for Universal Links to work on iOS
             return "https://vector.im/beta/#/room/%s/%s" % (
-                notif["room_id"],
-                notif["event_id"],
+                notif.room_id,
+                notif.event_id,
             )
         else:
-            return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"])
+            return "https://matrix.to/#/%s/%s" % (notif.room_id, notif.event_id)
 
     def _make_unsubscribe_link(
         self, user_id: str, app_id: str, email_address: str
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index da641aca47..957c9b780b 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -37,7 +37,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
                     room_id, user_id, last_unread_event_id
                 )
             )
-            if notifs["notify_count"] == 0:
+            if notifs.notify_count == 0:
                 continue
 
             if group_by_room:
@@ -45,7 +45,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
                 badge += 1
             else:
                 # increment the badge count by the number of unread messages in the room
-                badge += notifs["notify_count"]
+                badge += notifs.notify_count
     return badge
 
 
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index b12a332776..acd0c9e135 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -58,7 +58,7 @@ class NotificationsServlet(RestServlet):
             user_id, ReceiptTypes.READ
         )
 
-        notif_event_ids = [pa["event_id"] for pa in push_actions]
+        notif_event_ids = [pa.event_id for pa in push_actions]
         notif_events = await self.store.get_events(notif_event_ids)
 
         returned_push_actions = []
@@ -67,30 +67,30 @@ class NotificationsServlet(RestServlet):
 
         for pa in push_actions:
             returned_pa = {
-                "room_id": pa["room_id"],
-                "profile_tag": pa["profile_tag"],
-                "actions": pa["actions"],
-                "ts": pa["received_ts"],
+                "room_id": pa.room_id,
+                "profile_tag": pa.profile_tag,
+                "actions": pa.actions,
+                "ts": pa.received_ts,
                 "event": (
                     await self._event_serializer.serialize_event(
-                        notif_events[pa["event_id"]],
+                        notif_events[pa.event_id],
                         self.clock.time_msec(),
                         event_format=format_event_for_client_v2_without_room_id,
                     )
                 ),
             }
 
-            if pa["room_id"] not in receipts_by_room:
+            if pa.room_id not in receipts_by_room:
                 returned_pa["read"] = False
             else:
-                receipt = receipts_by_room[pa["room_id"]]
+                receipt = receipts_by_room[pa.room_id]
 
                 returned_pa["read"] = (
                     receipt["topological_ordering"],
                     receipt["stream_ordering"],
-                ) >= (pa["topological_ordering"], pa["stream_ordering"])
+                ) >= (pa.topological_ordering, pa.stream_ordering)
             returned_push_actions.append(returned_pa)
-            next_token = str(pa["stream_ordering"])
+            next_token = str(pa.stream_ordering)
 
         return 200, {"notifications": returned_push_actions, "next_token": next_token}
 
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index eacff3e432..98ea0e884c 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -16,7 +16,6 @@ import logging
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import attr
-from typing_extensions import TypedDict
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -34,29 +33,64 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}]
-DEFAULT_HIGHLIGHT_ACTION = [
+DEFAULT_NOTIF_ACTION: List[Union[dict, str]] = [
+    "notify",
+    {"set_tweak": "highlight", "value": False},
+]
+DEFAULT_HIGHLIGHT_ACTION: List[Union[dict, str]] = [
     "notify",
     {"set_tweak": "sound", "value": "default"},
     {"set_tweak": "highlight"},
 ]
 
 
-class BasePushAction(TypedDict):
-    event_id: str
-    actions: List[Union[dict, str]]
-
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class HttpPushAction:
+    """
+    HttpPushAction instances include the information used to generate HTTP
+    requests to a push gateway.
+    """
 
-class HttpPushAction(BasePushAction):
+    event_id: str
     room_id: str
     stream_ordering: int
+    actions: List[Union[dict, str]]
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class EmailPushAction(HttpPushAction):
+    """
+    EmailPushAction instances include the information used to render an email
+    push notification.
+    """
+
     received_ts: Optional[int]
 
 
-def _serialize_action(actions, is_highlight):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UserPushAction(EmailPushAction):
+    """
+    UserPushAction instances include the necessary information to respond to
+    /notifications requests.
+    """
+
+    topological_ordering: int
+    highlight: bool
+    profile_tag: str
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class NotifCounts:
+    """
+    The per-user, per-room count of notifications. Used by sync and push.
+    """
+
+    notify_count: int
+    unread_count: int
+    highlight_count: int
+
+
+def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
     """Custom serializer for actions. This allows us to "compress" common actions.
 
     We use the fact that most users have the same actions for notifs (and for
@@ -74,7 +108,7 @@ def _serialize_action(actions, is_highlight):
     return json_encoder.encode(actions)
 
 
-def _deserialize_action(actions, is_highlight):
+def _deserialize_action(actions: str, is_highlight: bool) -> List[Union[dict, str]]:
     """Custom deserializer for actions. This allows us to "compress" common actions"""
     if actions:
         return db_to_json(actions)
@@ -95,8 +129,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         super().__init__(database, db_conn, hs)
 
         # These get correctly set by _find_stream_orderings_for_times_txn
-        self.stream_ordering_month_ago = None
-        self.stream_ordering_day_ago = None
+        self.stream_ordering_month_ago: Optional[int] = None
+        self.stream_ordering_day_ago: Optional[int] = None
 
         cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn")
         self._find_stream_orderings_for_times_txn(cur)
@@ -120,7 +154,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         room_id: str,
         user_id: str,
         last_read_event_id: Optional[str],
-    ) -> Dict[str, int]:
+    ) -> NotifCounts:
         """Get the notification count, the highlight count and the unread message count
         for a given user in a given room after the given read receipt.
 
@@ -149,15 +183,15 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
     def _get_unread_counts_by_receipt_txn(
         self,
-        txn,
-        room_id,
-        user_id,
-        last_read_event_id,
-    ):
+        txn: LoggingTransaction,
+        room_id: str,
+        user_id: str,
+        last_read_event_id: Optional[str],
+    ) -> NotifCounts:
         stream_ordering = None
 
         if last_read_event_id is not None:
-            stream_ordering = self.get_stream_id_for_event_txn(
+            stream_ordering = self.get_stream_id_for_event_txn(  # type: ignore[attr-defined]
                 txn,
                 last_read_event_id,
                 allow_none=True,
@@ -175,13 +209,15 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 retcol="event_id",
             )
 
-            stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
+            stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)  # type: ignore[attr-defined]
 
         return self._get_unread_counts_by_pos_txn(
             txn, room_id, user_id, stream_ordering
         )
 
-    def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
+    def _get_unread_counts_by_pos_txn(
+        self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
+    ) -> NotifCounts:
         sql = (
             "SELECT"
             "   COUNT(CASE WHEN notif = 1 THEN 1 END),"
@@ -219,16 +255,16 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 # for this row.
                 unread_count += row[1]
 
-        return {
-            "notify_count": notif_count,
-            "unread_count": unread_count,
-            "highlight_count": highlight_count,
-        }
+        return NotifCounts(
+            notify_count=notif_count,
+            unread_count=unread_count,
+            highlight_count=highlight_count,
+        )
 
     async def get_push_action_users_in_range(
-        self, min_stream_ordering, max_stream_ordering
-    ):
-        def f(txn):
+        self, min_stream_ordering: int, max_stream_ordering: int
+    ) -> List[str]:
+        def f(txn: LoggingTransaction) -> List[str]:
             sql = (
                 "SELECT DISTINCT(user_id) FROM event_push_actions WHERE"
                 " stream_ordering >= ? AND stream_ordering <= ? AND notif = 1"
@@ -236,8 +272,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, (min_stream_ordering, max_stream_ordering))
             return [r[0] for r in txn]
 
-        ret = await self.db_pool.runInteraction("get_push_action_users_in_range", f)
-        return ret
+        return await self.db_pool.runInteraction("get_push_action_users_in_range", f)
 
     async def get_unread_push_actions_for_user_in_range_for_http(
         self,
@@ -263,7 +298,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         """
         # find rooms that have a read receipt in them and return the next
         # push actions
-        def get_after_receipt(txn):
+        def get_after_receipt(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[str, str, int, str, bool]]:
             # find rooms that have a read receipt in them and return the next
             # push actions
             sql = (
@@ -289,7 +326,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
-            return txn.fetchall()
+            return txn.fetchall()  # type: ignore[return-value]
 
         after_read_receipt = await self.db_pool.runInteraction(
             "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
@@ -298,7 +335,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # There are rooms with push actions in them but you don't have a read receipt in
         # them e.g. rooms you've been invited to, so get push actions for rooms which do
         # not have read receipts in them too.
-        def get_no_receipt(txn):
+        def get_no_receipt(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[str, str, int, str, bool]]:
             sql = (
                 "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
                 "   ep.highlight "
@@ -318,19 +357,19 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
-            return txn.fetchall()
+            return txn.fetchall()  # type: ignore[return-value]
 
         no_read_receipt = await self.db_pool.runInteraction(
             "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
         )
 
         notifs = [
-            {
-                "event_id": row[0],
-                "room_id": row[1],
-                "stream_ordering": row[2],
-                "actions": _deserialize_action(row[3], row[4]),
-            }
+            HttpPushAction(
+                event_id=row[0],
+                room_id=row[1],
+                stream_ordering=row[2],
+                actions=_deserialize_action(row[3], row[4]),
+            )
             for row in after_read_receipt + no_read_receipt
         ]
 
@@ -338,7 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # contain results from the first query, correctly ordered, followed
         # by results from the second query, but we want them all ordered
         # by stream_ordering, oldest first.
-        notifs.sort(key=lambda r: r["stream_ordering"])
+        notifs.sort(key=lambda r: r.stream_ordering)
 
         # Take only up to the limit. We have to stop at the limit because
         # one of the subqueries may have hit the limit.
@@ -368,7 +407,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         """
         # find rooms that have a read receipt in them and return the most recent
         # push actions
-        def get_after_receipt(txn):
+        def get_after_receipt(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[str, str, int, str, bool, int]]:
             sql = (
                 "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
                 "  ep.highlight, e.received_ts"
@@ -393,7 +434,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
-            return txn.fetchall()
+            return txn.fetchall()  # type: ignore[return-value]
 
         after_read_receipt = await self.db_pool.runInteraction(
             "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
@@ -402,7 +443,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # There are rooms with push actions in them but you don't have a read receipt in
         # them e.g. rooms you've been invited to, so get push actions for rooms which do
         # not have read receipts in them too.
-        def get_no_receipt(txn):
+        def get_no_receipt(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[str, str, int, str, bool, int]]:
             sql = (
                 "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
                 "   ep.highlight, e.received_ts"
@@ -422,7 +465,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
-            return txn.fetchall()
+            return txn.fetchall()  # type: ignore[return-value]
 
         no_read_receipt = await self.db_pool.runInteraction(
             "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
@@ -430,13 +473,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
         # Make a list of dicts from the two sets of results.
         notifs = [
-            {
-                "event_id": row[0],
-                "room_id": row[1],
-                "stream_ordering": row[2],
-                "actions": _deserialize_action(row[3], row[4]),
-                "received_ts": row[5],
-            }
+            EmailPushAction(
+                event_id=row[0],
+                room_id=row[1],
+                stream_ordering=row[2],
+                actions=_deserialize_action(row[3], row[4]),
+                received_ts=row[5],
+            )
             for row in after_read_receipt + no_read_receipt
         ]
 
@@ -444,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # contain results from the first query, correctly ordered, followed
         # by results from the second query, but we want them all ordered
         # by received_ts (most recent first)
-        notifs.sort(key=lambda r: -(r["received_ts"] or 0))
+        notifs.sort(key=lambda r: -(r.received_ts or 0))
 
         # Now return the first `limit`
         return notifs[:limit]
@@ -465,7 +508,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             True if there may be push to process, False if there definitely isn't.
         """
 
-        def _get_if_maybe_push_in_range_for_user_txn(txn):
+        def _get_if_maybe_push_in_range_for_user_txn(txn: LoggingTransaction) -> bool:
             sql = """
                 SELECT 1 FROM event_push_actions
                 WHERE user_id = ? AND stream_ordering > ? AND notif = 1
@@ -499,19 +542,21 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
         # This is a helper function for generating the necessary tuple that
         # can be used to insert into the `event_push_actions_staging` table.
-        def _gen_entry(user_id, actions):
+        def _gen_entry(
+            user_id: str, actions: List[Union[dict, str]]
+        ) -> Tuple[str, str, str, int, int, int]:
             is_highlight = 1 if _action_has_highlight(actions) else 0
             notif = 1 if "notify" in actions else 0
             return (
                 event_id,  # event_id column
                 user_id,  # user_id column
-                _serialize_action(actions, is_highlight),  # actions column
+                _serialize_action(actions, bool(is_highlight)),  # actions column
                 notif,  # notif column
                 is_highlight,  # highlight column
                 int(count_as_unread),  # unread column
             )
 
-        def _add_push_actions_to_staging_txn(txn):
+        def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None:
             # We don't use simple_insert_many here to avoid the overhead
             # of generating lists of dicts.
 
@@ -539,12 +584,11 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         """
 
         try:
-            res = await self.db_pool.simple_delete(
+            await self.db_pool.simple_delete(
                 table="event_push_actions_staging",
                 keyvalues={"event_id": event_id},
                 desc="remove_push_actions_from_staging",
             )
-            return res
         except Exception:
             # this method is called from an exception handler, so propagating
             # another exception here really isn't helpful - there's nothing
@@ -597,7 +641,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         )
 
     @staticmethod
-    def _find_first_stream_ordering_after_ts_txn(txn, ts):
+    def _find_first_stream_ordering_after_ts_txn(
+        txn: LoggingTransaction, ts: int
+    ) -> int:
         """
         Find the stream_ordering of the first event that was received on or
         after a given timestamp. This is relatively slow as there is no index
@@ -609,14 +655,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         stream_ordering
 
         Args:
-            txn (twisted.enterprise.adbapi.Transaction):
-            ts (int): timestamp to search for
+            txn:
+            ts: timestamp to search for
 
         Returns:
-            int: stream ordering
+            The stream ordering
         """
         txn.execute("SELECT MAX(stream_ordering) FROM events")
-        max_stream_ordering = txn.fetchone()[0]
+        max_stream_ordering = txn.fetchone()[0]  # type: ignore[index]
 
         if max_stream_ordering is None:
             return 0
@@ -672,8 +718,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
         return range_end
 
-    async def get_time_of_last_push_action_before(self, stream_ordering):
-        def f(txn):
+    async def get_time_of_last_push_action_before(
+        self, stream_ordering: int
+    ) -> Optional[int]:
+        def f(txn: LoggingTransaction) -> Optional[Tuple[int]]:
             sql = (
                 "SELECT e.received_ts"
                 " FROM event_push_actions AS ep"
@@ -683,7 +731,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 " LIMIT 1"
             )
             txn.execute(sql, (stream_ordering,))
-            return txn.fetchone()
+            return txn.fetchone()  # type: ignore[return-value]
 
         result = await self.db_pool.runInteraction(
             "get_time_of_last_push_action_before", f
@@ -691,7 +739,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         return result[0] if result else None
 
     @wrap_as_background_process("rotate_notifs")
-    async def _rotate_notifs(self):
+    async def _rotate_notifs(self) -> None:
         if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
             return
         self._doing_notif_rotation = True
@@ -709,7 +757,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         finally:
             self._doing_notif_rotation = False
 
-    def _rotate_notifs_txn(self, txn):
+    def _rotate_notifs_txn(self, txn: LoggingTransaction) -> bool:
         """Archives older notifications into event_push_summary. Returns whether
         the archiving process has caught up or not.
         """
@@ -734,6 +782,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         stream_row = txn.fetchone()
         if stream_row:
             (offset_stream_ordering,) = stream_row
+            assert self.stream_ordering_day_ago is not None
             rotate_to_stream_ordering = min(
                 self.stream_ordering_day_ago, offset_stream_ordering
             )
@@ -749,7 +798,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # We have caught up iff we were limited by `stream_ordering_day_ago`
         return caught_up
 
-    def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
+    def _rotate_notifs_before_txn(
+        self, txn: LoggingTransaction, rotate_to_stream_ordering: int
+    ) -> None:
         old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
             txn,
             table="event_push_summary_stream_ordering",
@@ -870,8 +921,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         )
 
     def _remove_old_push_actions_before_txn(
-        self, txn, room_id, user_id, stream_ordering
-    ):
+        self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
+    ) -> None:
         """
         Purges old push actions for a user and room before a given
         stream_ordering.
@@ -943,9 +994,15 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         )
 
     async def get_push_actions_for_user(
-        self, user_id, before=None, limit=50, only_highlight=False
-    ):
-        def f(txn):
+        self,
+        user_id: str,
+        before: Optional[str] = None,
+        limit: int = 50,
+        only_highlight: bool = False,
+    ) -> List[UserPushAction]:
+        def f(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[str, str, int, int, str, bool, str, int]]:
             before_clause = ""
             if before:
                 before_clause = "AND epa.stream_ordering < ?"
@@ -972,32 +1029,42 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
                 " LIMIT ?" % (before_clause,)
             )
             txn.execute(sql, args)
-            return self.db_pool.cursor_to_dict(txn)
+            return txn.fetchall()  # type: ignore[return-value]
 
         push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
-        for pa in push_actions:
-            pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
-        return push_actions
+        return [
+            UserPushAction(
+                event_id=row[0],
+                room_id=row[1],
+                stream_ordering=row[2],
+                actions=_deserialize_action(row[4], row[5]),
+                received_ts=row[7],
+                topological_ordering=row[3],
+                highlight=row[5],
+                profile_tag=row[6],
+            )
+            for row in push_actions
+        ]
 
 
-def _action_has_highlight(actions):
+def _action_has_highlight(actions: List[Union[dict, str]]) -> bool:
     for action in actions:
-        try:
-            if action.get("set_tweak", None) == "highlight":
-                return action.get("value", True)
-        except AttributeError:
-            pass
+        if not isinstance(action, dict):
+            continue
+
+        if action.get("set_tweak", None) == "highlight":
+            return action.get("value", True)
 
     return False
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class _EventPushSummary:
     """Summary of pending event push actions for a given user in a given room.
     Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
     """
 
-    unread_count = attr.ib(type=int)
-    stream_ordering = attr.ib(type=int)
-    old_user_id = attr.ib(type=str)
-    notif_count = attr.ib(type=int)
+    unread_count: int
+    stream_ordering: int
+    old_user_id: str
+    notif_count: int
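
Note on the new classes above: they are declared with @attr.s(slots=True, frozen=True, auto_attribs=True), so they behave as immutable value objects with generated __init__ and __eq__. A minimal sketch of the resulting behaviour (illustrative, using NotifCounts as defined above):

    counts = NotifCounts(notify_count=2, unread_count=5, highlight_count=1)
    counts.notify_count        # -> 2
    counts == NotifCounts(notify_count=2, unread_count=5, highlight_count=1)  # -> True
    counts.notify_count = 3    # raises attr.exceptions.FrozenInstanceError

The generated __eq__ is what lets the tests below compare results directly against NotifCounts(...) instances rather than plain dicts.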
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index b25a06b427..eca6a443af 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -20,6 +20,7 @@ from synapse.api.room_versions import RoomVersions
 from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
 from synapse.handlers.room import RoomEventSource
 from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.storage.databases.main.event_push_actions import NotifCounts
 from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
 from synapse.types import PersistedEventPosition
 
@@ -166,7 +167,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2, event1.event_id],
-            {"highlight_count": 0, "unread_count": 0, "notify_count": 0},
+            NotifCounts(highlight_count=0, unread_count=0, notify_count=0),
         )
 
         self.persist(
@@ -179,7 +180,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2, event1.event_id],
-            {"highlight_count": 0, "unread_count": 0, "notify_count": 1},
+            NotifCounts(highlight_count=0, unread_count=0, notify_count=1),
         )
 
         self.persist(
@@ -194,7 +195,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2, event1.event_id],
-            {"highlight_count": 1, "unread_count": 0, "notify_count": 2},
+            NotifCounts(highlight_count=1, unread_count=0, notify_count=2),
         )
 
     def test_get_rooms_for_user_with_stream_ordering(self):
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index bb5939ba4a..738f3ad1dc 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -14,6 +14,8 @@
 
 from unittest.mock import Mock
 
+from synapse.storage.databases.main.event_push_actions import NotifCounts
+
 from tests.unittest import HomeserverTestCase
 
 USER_ID = "@user:example.com"
@@ -57,11 +59,11 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
             )
             self.assertEquals(
                 counts,
-                {
-                    "notify_count": noitf_count,
-                    "unread_count": 0,  # Unread counts are tested in the sync tests.
-                    "highlight_count": highlight_count,
-                },
+                NotifCounts(
+                    notify_count=noitf_count,
+                    unread_count=0,  # Unread counts are tested in the sync tests.
+                    highlight_count=highlight_count,
+                ),
             )
 
         def _inject_actions(stream, action):