diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index eacff3e432..98ea0e884c 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -16,7 +16,6 @@ import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import attr
-from typing_extensions import TypedDict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -34,29 +33,64 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}]
-DEFAULT_HIGHLIGHT_ACTION = [
+DEFAULT_NOTIF_ACTION: List[Union[dict, str]] = [
+ "notify",
+ {"set_tweak": "highlight", "value": False},
+]
+DEFAULT_HIGHLIGHT_ACTION: List[Union[dict, str]] = [
"notify",
{"set_tweak": "sound", "value": "default"},
{"set_tweak": "highlight"},
]
-class BasePushAction(TypedDict):
- event_id: str
- actions: List[Union[dict, str]]
-
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class HttpPushAction:
+ """
+ HttpPushAction instances include the information used to generate HTTP
+ requests to a push gateway.
+    """
+
-class HttpPushAction(BasePushAction):
+ event_id: str
room_id: str
stream_ordering: int
+ actions: List[Union[dict, str]]
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class EmailPushAction(HttpPushAction):
+ """
+ EmailPushAction instances include the information used to render an email
+ push notification.
+ """
+
received_ts: Optional[int]
-def _serialize_action(actions, is_highlight):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UserPushAction(EmailPushAction):
+ """
+ UserPushAction instances include the necessary information to respond to
+ /notifications requests.
+ """
+
+ topological_ordering: int
+ highlight: bool
+ profile_tag: str
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class NotifCounts:
+ """
+ The per-user, per-room count of notifications. Used by sync and push.
+ """
+
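+    # The number of unread notifications: events with a "notify" push action
+    # since the user's last read receipt in the room.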
+ notify_count: int
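+    # The number of events counted as unread for the user in the room, whether
+    # or not they triggered a notification.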
+ unread_count: int
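+    # The number of unread highlight notifications (e.g. mentions).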
+ highlight_count: int
+
+
+def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
"""Custom serializer for actions. This allows us to "compress" common actions.
We use the fact that most users have the same actions for notifs (and for
@@ -74,7 +108,7 @@ def _serialize_action(actions, is_highlight):
return json_encoder.encode(actions)
-def _deserialize_action(actions, is_highlight):
+def _deserialize_action(actions: str, is_highlight: bool) -> List[Union[dict, str]]:
"""Custom deserializer for actions. This allows us to "compress" common actions"""
if actions:
return db_to_json(actions)
@@ -95,8 +129,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
super().__init__(database, db_conn, hs)
# These get correctly set by _find_stream_orderings_for_times_txn
- self.stream_ordering_month_ago = None
- self.stream_ordering_day_ago = None
+ self.stream_ordering_month_ago: Optional[int] = None
+ self.stream_ordering_day_ago: Optional[int] = None
cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn")
self._find_stream_orderings_for_times_txn(cur)
@@ -120,7 +154,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
room_id: str,
user_id: str,
last_read_event_id: Optional[str],
- ) -> Dict[str, int]:
+ ) -> NotifCounts:
"""Get the notification count, the highlight count and the unread message count
for a given user in a given room after the given read receipt.
@@ -149,15 +183,15 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def _get_unread_counts_by_receipt_txn(
self,
- txn,
- room_id,
- user_id,
- last_read_event_id,
- ):
+ txn: LoggingTransaction,
+ room_id: str,
+ user_id: str,
+ last_read_event_id: Optional[str],
+ ) -> NotifCounts:
stream_ordering = None
if last_read_event_id is not None:
- stream_ordering = self.get_stream_id_for_event_txn(
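+            # `get_stream_id_for_event_txn` is provided by a different store
+            # class that is mixed into the concrete DataStore at runtime, so
+            # mypy cannot see it here; hence the attr-defined ignores.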
+ stream_ordering = self.get_stream_id_for_event_txn( # type: ignore[attr-defined]
txn,
last_read_event_id,
allow_none=True,
@@ -175,13 +209,15 @@ class EventPushActionsWorkerStore(SQLBaseStore):
retcol="event_id",
)
- stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
+ stream_ordering = self.get_stream_id_for_event_txn(txn, event_id) # type: ignore[attr-defined]
return self._get_unread_counts_by_pos_txn(
txn, room_id, user_id, stream_ordering
)
- def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
+ def _get_unread_counts_by_pos_txn(
+ self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
+ ) -> NotifCounts:
sql = (
"SELECT"
" COUNT(CASE WHEN notif = 1 THEN 1 END),"
@@ -219,16 +255,16 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# for this row.
unread_count += row[1]
- return {
- "notify_count": notif_count,
- "unread_count": unread_count,
- "highlight_count": highlight_count,
- }
+ return NotifCounts(
+ notify_count=notif_count,
+ unread_count=unread_count,
+ highlight_count=highlight_count,
+ )
async def get_push_action_users_in_range(
- self, min_stream_ordering, max_stream_ordering
- ):
- def f(txn):
+ self, min_stream_ordering: int, max_stream_ordering: int
+ ) -> List[str]:
+ def f(txn: LoggingTransaction) -> List[str]:
sql = (
"SELECT DISTINCT(user_id) FROM event_push_actions WHERE"
" stream_ordering >= ? AND stream_ordering <= ? AND notif = 1"
@@ -236,8 +272,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn]
- ret = await self.db_pool.runInteraction("get_push_action_users_in_range", f)
- return ret
+ return await self.db_pool.runInteraction("get_push_action_users_in_range", f)
async def get_unread_push_actions_for_user_in_range_for_http(
self,
@@ -263,7 +298,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"""
# find rooms that have a read receipt in them and return the next
# push actions
- def get_after_receipt(txn):
+ def get_after_receipt(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, int, str, bool]]:
# find rooms that have a read receipt in them and return the next
# push actions
sql = (
@@ -289,7 +326,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
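+            # The cursor is typed as returning generic rows; we know the exact
+            # row shape from the SELECT above, hence the return-value ignore.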
- return txn.fetchall()
+ return txn.fetchall() # type: ignore[return-value]
after_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
@@ -298,7 +335,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# There are rooms with push actions in them but you don't have a read receipt in
# them e.g. rooms you've been invited to, so get push actions for rooms which do
# not have read receipts in them too.
- def get_no_receipt(txn):
+ def get_no_receipt(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, int, str, bool]]:
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" ep.highlight "
@@ -318,19 +357,19 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
- return txn.fetchall()
+ return txn.fetchall() # type: ignore[return-value]
no_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
)
notifs = [
- {
- "event_id": row[0],
- "room_id": row[1],
- "stream_ordering": row[2],
- "actions": _deserialize_action(row[3], row[4]),
- }
+ HttpPushAction(
+ event_id=row[0],
+ room_id=row[1],
+ stream_ordering=row[2],
+ actions=_deserialize_action(row[3], row[4]),
+ )
for row in after_read_receipt + no_read_receipt
]
@@ -338,7 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# contain results from the first query, correctly ordered, followed
# by results from the second query, but we want them all ordered
# by stream_ordering, oldest first.
- notifs.sort(key=lambda r: r["stream_ordering"])
+ notifs.sort(key=lambda r: r.stream_ordering)
# Take only up to the limit. We have to stop at the limit because
# one of the subqueries may have hit the limit.
@@ -368,7 +407,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"""
# find rooms that have a read receipt in them and return the most recent
# push actions
- def get_after_receipt(txn):
+ def get_after_receipt(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, int, str, bool, int]]:
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" ep.highlight, e.received_ts"
@@ -393,7 +434,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
- return txn.fetchall()
+ return txn.fetchall() # type: ignore[return-value]
after_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
@@ -402,7 +443,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# There are rooms with push actions in them but you don't have a read receipt in
# them e.g. rooms you've been invited to, so get push actions for rooms which do
# not have read receipts in them too.
- def get_no_receipt(txn):
+ def get_no_receipt(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, int, str, bool, int]]:
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" ep.highlight, e.received_ts"
@@ -422,7 +465,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
- return txn.fetchall()
+ return txn.fetchall() # type: ignore[return-value]
no_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
@@ -430,13 +473,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
-        # Make a list of dicts from the two sets of results.
+        # Make a list of `EmailPushAction`s from the two sets of results.
notifs = [
- {
- "event_id": row[0],
- "room_id": row[1],
- "stream_ordering": row[2],
- "actions": _deserialize_action(row[3], row[4]),
- "received_ts": row[5],
- }
+ EmailPushAction(
+ event_id=row[0],
+ room_id=row[1],
+ stream_ordering=row[2],
+ actions=_deserialize_action(row[3], row[4]),
+ received_ts=row[5],
+ )
for row in after_read_receipt + no_read_receipt
]
@@ -444,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# contain results from the first query, correctly ordered, followed
# by results from the second query, but we want them all ordered
# by received_ts (most recent first)
- notifs.sort(key=lambda r: -(r["received_ts"] or 0))
+ notifs.sort(key=lambda r: -(r.received_ts or 0))
# Now return the first `limit`
return notifs[:limit]
@@ -465,7 +508,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
True if there may be push to process, False if there definitely isn't.
"""
- def _get_if_maybe_push_in_range_for_user_txn(txn):
+ def _get_if_maybe_push_in_range_for_user_txn(txn: LoggingTransaction) -> bool:
sql = """
SELECT 1 FROM event_push_actions
WHERE user_id = ? AND stream_ordering > ? AND notif = 1
@@ -499,19 +542,21 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# This is a helper function for generating the necessary tuple that
# can be used to insert into the `event_push_actions_staging` table.
- def _gen_entry(user_id, actions):
+ def _gen_entry(
+ user_id: str, actions: List[Union[dict, str]]
+ ) -> Tuple[str, str, str, int, int, int]:
is_highlight = 1 if _action_has_highlight(actions) else 0
notif = 1 if "notify" in actions else 0
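+            # The staging table stores these flags as 0/1 integers, while
+            # `_serialize_action` takes a real bool, hence the cast below.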
return (
event_id, # event_id column
user_id, # user_id column
- _serialize_action(actions, is_highlight), # actions column
+ _serialize_action(actions, bool(is_highlight)), # actions column
notif, # notif column
is_highlight, # highlight column
int(count_as_unread), # unread column
)
- def _add_push_actions_to_staging_txn(txn):
+ def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None:
# We don't use simple_insert_many here to avoid the overhead
# of generating lists of dicts.
@@ -539,12 +584,11 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"""
try:
- res = await self.db_pool.simple_delete(
+ await self.db_pool.simple_delete(
table="event_push_actions_staging",
keyvalues={"event_id": event_id},
desc="remove_push_actions_from_staging",
)
- return res
except Exception:
# this method is called from an exception handler, so propagating
# another exception here really isn't helpful - there's nothing
@@ -597,7 +641,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
@staticmethod
- def _find_first_stream_ordering_after_ts_txn(txn, ts):
+ def _find_first_stream_ordering_after_ts_txn(
+ txn: LoggingTransaction, ts: int
+ ) -> int:
"""
Find the stream_ordering of the first event that was received on or
after a given timestamp. This is relatively slow as there is no index
@@ -609,14 +655,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
stream_ordering
Args:
- txn (twisted.enterprise.adbapi.Transaction):
- ts (int): timestamp to search for
+            txn: the database transaction
+ ts: timestamp to search for
Returns:
- int: stream ordering
+ The stream ordering
"""
txn.execute("SELECT MAX(stream_ordering) FROM events")
- max_stream_ordering = txn.fetchone()[0]
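+        # An aggregate query with no GROUP BY always returns exactly one row,
+        # so fetchone() cannot return None here (hence the index type-ignore).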
+ max_stream_ordering = txn.fetchone()[0] # type: ignore[index]
if max_stream_ordering is None:
return 0
@@ -672,8 +718,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return range_end
- async def get_time_of_last_push_action_before(self, stream_ordering):
- def f(txn):
+ async def get_time_of_last_push_action_before(
+ self, stream_ordering: int
+ ) -> Optional[int]:
+ def f(txn: LoggingTransaction) -> Optional[Tuple[int]]:
sql = (
"SELECT e.received_ts"
" FROM event_push_actions AS ep"
@@ -683,7 +731,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" LIMIT 1"
)
txn.execute(sql, (stream_ordering,))
- return txn.fetchone()
+ return txn.fetchone() # type: ignore[return-value]
result = await self.db_pool.runInteraction(
"get_time_of_last_push_action_before", f
@@ -691,7 +739,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return result[0] if result else None
@wrap_as_background_process("rotate_notifs")
- async def _rotate_notifs(self):
+ async def _rotate_notifs(self) -> None:
if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
return
self._doing_notif_rotation = True
@@ -709,7 +757,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
finally:
self._doing_notif_rotation = False
- def _rotate_notifs_txn(self, txn):
+ def _rotate_notifs_txn(self, txn: LoggingTransaction) -> bool:
"""Archives older notifications into event_push_summary. Returns whether
the archiving process has caught up or not.
"""
@@ -734,6 +782,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
stream_row = txn.fetchone()
if stream_row:
(offset_stream_ordering,) = stream_row
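+            # `_rotate_notifs` returns early while `stream_ordering_day_ago` is
+            # still None, so it is always set by the time we get here.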
+ assert self.stream_ordering_day_ago is not None
rotate_to_stream_ordering = min(
self.stream_ordering_day_ago, offset_stream_ordering
)
@@ -749,7 +798,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# We have caught up iff we were limited by `stream_ordering_day_ago`
return caught_up
- def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
+ def _rotate_notifs_before_txn(
+ self, txn: LoggingTransaction, rotate_to_stream_ordering: int
+ ) -> None:
old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
@@ -870,8 +921,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
def _remove_old_push_actions_before_txn(
- self, txn, room_id, user_id, stream_ordering
- ):
+ self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
+ ) -> None:
"""
Purges old push actions for a user and room before a given
stream_ordering.
@@ -943,9 +994,15 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
async def get_push_actions_for_user(
- self, user_id, before=None, limit=50, only_highlight=False
- ):
- def f(txn):
+ self,
+ user_id: str,
+ before: Optional[str] = None,
+ limit: int = 50,
+ only_highlight: bool = False,
+ ) -> List[UserPushAction]:
+ def f(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, int, int, str, bool, str, int]]:
before_clause = ""
if before:
before_clause = "AND epa.stream_ordering < ?"
@@ -972,32 +1029,42 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
" LIMIT ?" % (before_clause,)
)
txn.execute(sql, args)
- return self.db_pool.cursor_to_dict(txn)
+ return txn.fetchall() # type: ignore[return-value]
push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
- for pa in push_actions:
- pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
- return push_actions
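+        # Row shape, matching the tuple type on `f` above: event_id, room_id,
+        # stream_ordering, topological_ordering, actions, highlight,
+        # profile_tag, received_ts.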
+ return [
+ UserPushAction(
+ event_id=row[0],
+ room_id=row[1],
+ stream_ordering=row[2],
+ actions=_deserialize_action(row[4], row[5]),
+ received_ts=row[7],
+ topological_ordering=row[3],
+ highlight=row[5],
+ profile_tag=row[6],
+ )
+ for row in push_actions
+ ]
-def _action_has_highlight(actions):
+def _action_has_highlight(actions: List[Union[dict, str]]) -> bool:
for action in actions:
- try:
- if action.get("set_tweak", None) == "highlight":
- return action.get("value", True)
- except AttributeError:
- pass
+ if not isinstance(action, dict):
+ continue
+
+ if action.get("set_tweak", None) == "highlight":
+ return action.get("value", True)
return False
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _EventPushSummary:
"""Summary of pending event push actions for a given user in a given room.
Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
"""
- unread_count = attr.ib(type=int)
- stream_ordering = attr.ib(type=int)
- old_user_id = attr.ib(type=str)
- notif_count = attr.ib(type=int)
+ unread_count: int
+ stream_ordering: int
+ old_user_id: str
+ notif_count: int
|