summary refs log tree commit diff
path: root/synapse/storage/databases/main/event_push_actions.py
diff options
context:
space:
mode:
authorOlivier Wilkinson (reivilibre) <oliverw@matrix.org>2022-01-05 14:19:39 +0000
committerOlivier Wilkinson (reivilibre) <oliverw@matrix.org>2022-01-05 14:19:39 +0000
commit717a5c085a593f00b9454e0155e16f0466b77fd3 (patch)
treef92d46b057c88443443409a8fd53e5c749917bd9 /synapse/storage/databases/main/event_push_actions.py
parentMerge branch 'rav/no_bundle_aggregations_in_sync' into matrix-org-hotfixes (diff)
parentMention drop of support in changelog (diff)
downloadsynapse-717a5c085a593f00b9454e0155e16f0466b77fd3.tar.xz
Merge branch 'release-v1.50' into matrix-org-hotfixes
Diffstat (limited to 'synapse/storage/databases/main/event_push_actions.py')
-rw-r--r--synapse/storage/databases/main/event_push_actions.py273
1 files changed, 178 insertions, 95 deletions
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py

index 3efdd0c920..a98e6b2593 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py
@@ -13,14 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast import attr -from typing_extensions import TypedDict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -30,29 +33,64 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}] -DEFAULT_HIGHLIGHT_ACTION = [ +DEFAULT_NOTIF_ACTION: List[Union[dict, str]] = [ + "notify", + {"set_tweak": "highlight", "value": False}, +] +DEFAULT_HIGHLIGHT_ACTION: List[Union[dict, str]] = [ "notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"}, ] -class BasePushAction(TypedDict): - event_id: str - actions: List[Union[dict, str]] - +@attr.s(slots=True, frozen=True, auto_attribs=True) +class HttpPushAction: + """ + HttpPushAction instances include the information used to generate HTTP + requests to a push gateway. + """ -class HttpPushAction(BasePushAction): + event_id: str room_id: str stream_ordering: int + actions: List[Union[dict, str]] +@attr.s(slots=True, frozen=True, auto_attribs=True) class EmailPushAction(HttpPushAction): + """ + EmailPushAction instances include the information used to render an email + push notification. + """ + received_ts: Optional[int] -def _serialize_action(actions, is_highlight): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class UserPushAction(EmailPushAction): + """ + UserPushAction instances include the necessary information to respond to + /notifications requests. + """ + + topological_ordering: int + highlight: bool + profile_tag: str + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class NotifCounts: + """ + The per-user, per-room count of notifications. Used by sync and push. + """ + + notify_count: int + unread_count: int + highlight_count: int + + +def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str: """Custom serializer for actions. This allows us to "compress" common actions. We use the fact that most users have the same actions for notifs (and for @@ -70,7 +108,7 @@ def _serialize_action(actions, is_highlight): return json_encoder.encode(actions) -def _deserialize_action(actions, is_highlight): +def _deserialize_action(actions: str, is_highlight: bool) -> List[Union[dict, str]]: """Custom deserializer for actions. This allows us to "compress" common actions""" if actions: return db_to_json(actions) @@ -82,12 +120,17 @@ def _deserialize_action(actions, is_highlight): class EventPushActionsWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) # These get correctly set by _find_stream_orderings_for_times_txn - self.stream_ordering_month_ago = None - self.stream_ordering_day_ago = None + self.stream_ordering_month_ago: Optional[int] = None + self.stream_ordering_day_ago: Optional[int] = None cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn") self._find_stream_orderings_for_times_txn(cur) @@ -111,7 +154,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): room_id: str, user_id: str, last_read_event_id: Optional[str], - ) -> Dict[str, int]: + ) -> NotifCounts: """Get the notification count, the highlight count and the unread message count for a given user in a given room after the given read receipt. @@ -140,15 +183,15 @@ class EventPushActionsWorkerStore(SQLBaseStore): def _get_unread_counts_by_receipt_txn( self, - txn, - room_id, - user_id, - last_read_event_id, - ): + txn: LoggingTransaction, + room_id: str, + user_id: str, + last_read_event_id: Optional[str], + ) -> NotifCounts: stream_ordering = None if last_read_event_id is not None: - stream_ordering = self.get_stream_id_for_event_txn( + stream_ordering = self.get_stream_id_for_event_txn( # type: ignore[attr-defined] txn, last_read_event_id, allow_none=True, @@ -166,13 +209,15 @@ class EventPushActionsWorkerStore(SQLBaseStore): retcol="event_id", ) - stream_ordering = self.get_stream_id_for_event_txn(txn, event_id) + stream_ordering = self.get_stream_id_for_event_txn(txn, event_id) # type: ignore[attr-defined] return self._get_unread_counts_by_pos_txn( txn, room_id, user_id, stream_ordering ) - def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering): + def _get_unread_counts_by_pos_txn( + self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int + ) -> NotifCounts: sql = ( "SELECT" " COUNT(CASE WHEN notif = 1 THEN 1 END)," @@ -210,16 +255,16 @@ class EventPushActionsWorkerStore(SQLBaseStore): # for this row. unread_count += row[1] - return { - "notify_count": notif_count, - "unread_count": unread_count, - "highlight_count": highlight_count, - } + return NotifCounts( + notify_count=notif_count, + unread_count=unread_count, + highlight_count=highlight_count, + ) async def get_push_action_users_in_range( - self, min_stream_ordering, max_stream_ordering - ): - def f(txn): + self, min_stream_ordering: int, max_stream_ordering: int + ) -> List[str]: + def f(txn: LoggingTransaction) -> List[str]: sql = ( "SELECT DISTINCT(user_id) FROM event_push_actions WHERE" " stream_ordering >= ? AND stream_ordering <= ? AND notif = 1" @@ -227,8 +272,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (min_stream_ordering, max_stream_ordering)) return [r[0] for r in txn] - ret = await self.db_pool.runInteraction("get_push_action_users_in_range", f) - return ret + return await self.db_pool.runInteraction("get_push_action_users_in_range", f) async def get_unread_push_actions_for_user_in_range_for_http( self, @@ -254,7 +298,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): """ # find rooms that have a read receipt in them and return the next # push actions - def get_after_receipt(txn): + def get_after_receipt( + txn: LoggingTransaction, + ) -> List[Tuple[str, str, int, str, bool]]: # find rooms that have a read receipt in them and return the next # push actions sql = ( @@ -280,7 +326,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) - return txn.fetchall() + return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) after_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt @@ -289,7 +335,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): # There are rooms with push actions in them but you don't have a read receipt in # them e.g. rooms you've been invited to, so get push actions for rooms which do # not have read receipts in them too. - def get_no_receipt(txn): + def get_no_receipt( + txn: LoggingTransaction, + ) -> List[Tuple[str, str, int, str, bool]]: sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," " ep.highlight " @@ -309,19 +357,19 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) - return txn.fetchall() + return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) no_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt ) notifs = [ - { - "event_id": row[0], - "room_id": row[1], - "stream_ordering": row[2], - "actions": _deserialize_action(row[3], row[4]), - } + HttpPushAction( + event_id=row[0], + room_id=row[1], + stream_ordering=row[2], + actions=_deserialize_action(row[3], row[4]), + ) for row in after_read_receipt + no_read_receipt ] @@ -329,7 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): # contain results from the first query, correctly ordered, followed # by results from the second query, but we want them all ordered # by stream_ordering, oldest first. - notifs.sort(key=lambda r: r["stream_ordering"]) + notifs.sort(key=lambda r: r.stream_ordering) # Take only up to the limit. We have to stop at the limit because # one of the subqueries may have hit the limit. @@ -359,7 +407,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): """ # find rooms that have a read receipt in them and return the most recent # push actions - def get_after_receipt(txn): + def get_after_receipt( + txn: LoggingTransaction, + ) -> List[Tuple[str, str, int, str, bool, int]]: sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," " ep.highlight, e.received_ts" @@ -384,7 +434,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) - return txn.fetchall() + return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) after_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt @@ -393,7 +443,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): # There are rooms with push actions in them but you don't have a read receipt in # them e.g. rooms you've been invited to, so get push actions for rooms which do # not have read receipts in them too. - def get_no_receipt(txn): + def get_no_receipt( + txn: LoggingTransaction, + ) -> List[Tuple[str, str, int, str, bool, int]]: sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," " ep.highlight, e.received_ts" @@ -413,7 +465,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) - return txn.fetchall() + return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) no_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt @@ -421,13 +473,13 @@ class EventPushActionsWorkerStore(SQLBaseStore): # Make a list of dicts from the two sets of results. notifs = [ - { - "event_id": row[0], - "room_id": row[1], - "stream_ordering": row[2], - "actions": _deserialize_action(row[3], row[4]), - "received_ts": row[5], - } + EmailPushAction( + event_id=row[0], + room_id=row[1], + stream_ordering=row[2], + actions=_deserialize_action(row[3], row[4]), + received_ts=row[5], + ) for row in after_read_receipt + no_read_receipt ] @@ -435,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): # contain results from the first query, correctly ordered, followed # by results from the second query, but we want them all ordered # by received_ts (most recent first) - notifs.sort(key=lambda r: -(r["received_ts"] or 0)) + notifs.sort(key=lambda r: -(r.received_ts or 0)) # Now return the first `limit` return notifs[:limit] @@ -456,7 +508,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): True if there may be push to process, False if there definitely isn't. """ - def _get_if_maybe_push_in_range_for_user_txn(txn): + def _get_if_maybe_push_in_range_for_user_txn(txn: LoggingTransaction) -> bool: sql = """ SELECT 1 FROM event_push_actions WHERE user_id = ? AND stream_ordering > ? AND notif = 1 @@ -490,19 +542,21 @@ class EventPushActionsWorkerStore(SQLBaseStore): # This is a helper function for generating the necessary tuple that # can be used to insert into the `event_push_actions_staging` table. - def _gen_entry(user_id, actions): + def _gen_entry( + user_id: str, actions: List[Union[dict, str]] + ) -> Tuple[str, str, str, int, int, int]: is_highlight = 1 if _action_has_highlight(actions) else 0 notif = 1 if "notify" in actions else 0 return ( event_id, # event_id column user_id, # user_id column - _serialize_action(actions, is_highlight), # actions column + _serialize_action(actions, bool(is_highlight)), # actions column notif, # notif column is_highlight, # highlight column int(count_as_unread), # unread column ) - def _add_push_actions_to_staging_txn(txn): + def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None: # We don't use simple_insert_many here to avoid the overhead # of generating lists of dicts. @@ -530,12 +584,11 @@ class EventPushActionsWorkerStore(SQLBaseStore): """ try: - res = await self.db_pool.simple_delete( + await self.db_pool.simple_delete( table="event_push_actions_staging", keyvalues={"event_id": event_id}, desc="remove_push_actions_from_staging", ) - return res except Exception: # this method is called from an exception handler, so propagating # another exception here really isn't helpful - there's nothing @@ -588,7 +641,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) @staticmethod - def _find_first_stream_ordering_after_ts_txn(txn, ts): + def _find_first_stream_ordering_after_ts_txn( + txn: LoggingTransaction, ts: int + ) -> int: """ Find the stream_ordering of the first event that was received on or after a given timestamp. This is relatively slow as there is no index @@ -600,14 +655,14 @@ class EventPushActionsWorkerStore(SQLBaseStore): stream_ordering Args: - txn (twisted.enterprise.adbapi.Transaction): - ts (int): timestamp to search for + txn: + ts: timestamp to search for Returns: - int: stream ordering + The stream ordering """ txn.execute("SELECT MAX(stream_ordering) FROM events") - max_stream_ordering = txn.fetchone()[0] + max_stream_ordering = cast(Tuple[Optional[int]], txn.fetchone())[0] if max_stream_ordering is None: return 0 @@ -663,8 +718,10 @@ class EventPushActionsWorkerStore(SQLBaseStore): return range_end - async def get_time_of_last_push_action_before(self, stream_ordering): - def f(txn): + async def get_time_of_last_push_action_before( + self, stream_ordering: int + ) -> Optional[int]: + def f(txn: LoggingTransaction) -> Optional[Tuple[int]]: sql = ( "SELECT e.received_ts" " FROM event_push_actions AS ep" @@ -674,7 +731,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): " LIMIT 1" ) txn.execute(sql, (stream_ordering,)) - return txn.fetchone() + return cast(Optional[Tuple[int]], txn.fetchone()) result = await self.db_pool.runInteraction( "get_time_of_last_push_action_before", f @@ -682,7 +739,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): return result[0] if result else None @wrap_as_background_process("rotate_notifs") - async def _rotate_notifs(self): + async def _rotate_notifs(self) -> None: if self._doing_notif_rotation or self.stream_ordering_day_ago is None: return self._doing_notif_rotation = True @@ -700,7 +757,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): finally: self._doing_notif_rotation = False - def _rotate_notifs_txn(self, txn): + def _rotate_notifs_txn(self, txn: LoggingTransaction) -> bool: """Archives older notifications into event_push_summary. Returns whether the archiving process has caught up or not. """ @@ -725,6 +782,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): stream_row = txn.fetchone() if stream_row: (offset_stream_ordering,) = stream_row + assert self.stream_ordering_day_ago is not None rotate_to_stream_ordering = min( self.stream_ordering_day_ago, offset_stream_ordering ) @@ -740,7 +798,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): # We have caught up iff we were limited by `stream_ordering_day_ago` return caught_up - def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): + def _rotate_notifs_before_txn( + self, txn: LoggingTransaction, rotate_to_stream_ordering: int + ) -> None: old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", @@ -861,8 +921,8 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) def _remove_old_push_actions_before_txn( - self, txn, room_id, user_id, stream_ordering - ): + self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int + ) -> None: """ Purges old push actions for a user and room before a given stream_ordering. @@ -910,7 +970,12 @@ class EventPushActionsWorkerStore(SQLBaseStore): class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -929,9 +994,15 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ) async def get_push_actions_for_user( - self, user_id, before=None, limit=50, only_highlight=False - ): - def f(txn): + self, + user_id: str, + before: Optional[str] = None, + limit: int = 50, + only_highlight: bool = False, + ) -> List[UserPushAction]: + def f( + txn: LoggingTransaction, + ) -> List[Tuple[str, str, int, int, str, bool, str, int]]: before_clause = "" if before: before_clause = "AND epa.stream_ordering < ?" @@ -958,32 +1029,44 @@ class EventPushActionsStore(EventPushActionsWorkerStore): " LIMIT ?" % (before_clause,) ) txn.execute(sql, args) - return self.db_pool.cursor_to_dict(txn) + return cast( + List[Tuple[str, str, int, int, str, bool, str, int]], txn.fetchall() + ) push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f) - for pa in push_actions: - pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) - return push_actions + return [ + UserPushAction( + event_id=row[0], + room_id=row[1], + stream_ordering=row[2], + actions=_deserialize_action(row[4], row[5]), + received_ts=row[7], + topological_ordering=row[3], + highlight=row[5], + profile_tag=row[6], + ) + for row in push_actions + ] -def _action_has_highlight(actions): +def _action_has_highlight(actions: List[Union[dict, str]]) -> bool: for action in actions: - try: - if action.get("set_tweak", None) == "highlight": - return action.get("value", True) - except AttributeError: - pass + if not isinstance(action, dict): + continue + + if action.get("set_tweak", None) == "highlight": + return action.get("value", True) return False -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class _EventPushSummary: """Summary of pending event push actions for a given user in a given room. Used in _rotate_notifs_before_txn to manipulate results from event_push_actions. """ - unread_count = attr.ib(type=int) - stream_ordering = attr.ib(type=int) - old_user_id = attr.ib(type=str) - notif_count = attr.ib(type=int) + unread_count: int + stream_ordering: int + old_user_id: str + notif_count: int