summary refs log tree commit diff
path: root/synapse/storage/databases/main/event_push_actions.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/event_push_actions.py')
-rw-r--r--synapse/storage/databases/main/event_push_actions.py258
1 files changed, 175 insertions, 83 deletions
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py

index e8834b2162..001d06378d 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py
@@ -15,7 +15,9 @@ # limitations under the License. import logging -from typing import List +from typing import Dict, List, Optional, Tuple, Union + +import attr from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json @@ -88,8 +90,26 @@ class EventPushActionsWorkerStore(SQLBaseStore): @cached(num_args=3, tree=True, max_entries=5000) async def get_unread_event_push_actions_by_room_for_user( - self, room_id, user_id, last_read_event_id - ): + self, room_id: str, user_id: str, last_read_event_id: Optional[str], + ) -> Dict[str, int]: + """Get the notification count, the highlight count and the unread message count + for a given user in a given room after the given read receipt. + + Note that this function assumes the user to be a current member of the room, + since it's either called by the sync handler to handle joined room entries, or by + the HTTP pusher to calculate the badge of unread joined rooms. + + Args: + room_id: The room to retrieve the counts in. + user_id: The user to retrieve the counts for. + last_read_event_id: The event associated with the latest read receipt for + this user in this room. None if no receipt for this user in this room. + + Returns + A dict containing the counts mentioned earlier in this docstring, + respectively under the keys "notify_count", "highlight_count" and + "unread_count". + """ return await self.db_pool.runInteraction( "get_unread_event_push_actions_by_room", self._get_unread_counts_by_receipt_txn, @@ -99,69 +119,71 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) def _get_unread_counts_by_receipt_txn( - self, txn, room_id, user_id, last_read_event_id + self, txn, room_id, user_id, last_read_event_id, ): - sql = ( - "SELECT stream_ordering" - " FROM events" - " WHERE room_id = ? AND event_id = ?" - ) - txn.execute(sql, (room_id, last_read_event_id)) - results = txn.fetchall() - if len(results) == 0: - return {"notify_count": 0, "highlight_count": 0} + stream_ordering = None + + if last_read_event_id is not None: + stream_ordering = self.get_stream_id_for_event_txn( + txn, last_read_event_id, allow_none=True, + ) + + if stream_ordering is None: + # Either last_read_event_id is None, or it's an event we don't have (e.g. + # because it's been purged), in which case retrieve the stream ordering for + # the latest membership event from this user in this room (which we assume is + # a join). + event_id = self.db_pool.simple_select_one_onecol_txn( + txn=txn, + table="local_current_membership", + keyvalues={"room_id": room_id, "user_id": user_id}, + retcol="event_id", + ) - stream_ordering = results[0][0] + stream_ordering = self.get_stream_id_for_event_txn(txn, event_id) return self._get_unread_counts_by_pos_txn( txn, room_id, user_id, stream_ordering ) def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering): - - # First get number of notifications. - # We don't need to put a notif=1 clause as all rows always have - # notif=1 sql = ( - "SELECT count(*)" + "SELECT" + " COUNT(CASE WHEN notif = 1 THEN 1 END)," + " COUNT(CASE WHEN highlight = 1 THEN 1 END)," + " COUNT(CASE WHEN unread = 1 THEN 1 END)" " FROM event_push_actions ea" - " WHERE" - " user_id = ?" - " AND room_id = ?" - " AND stream_ordering > ?" + " WHERE user_id = ?" + " AND room_id = ?" + " AND stream_ordering > ?" ) txn.execute(sql, (user_id, room_id, stream_ordering)) row = txn.fetchone() - notify_count = row[0] if row else 0 + + (notif_count, highlight_count, unread_count) = (0, 0, 0) + + if row: + (notif_count, highlight_count, unread_count) = row txn.execute( """ - SELECT notif_count FROM event_push_summary - WHERE room_id = ? AND user_id = ? AND stream_ordering > ? - """, + SELECT notif_count, unread_count FROM event_push_summary + WHERE room_id = ? AND user_id = ? AND stream_ordering > ? + """, (room_id, user_id, stream_ordering), ) - rows = txn.fetchall() - if rows: - notify_count += rows[0][0] - - # Now get the number of highlights - sql = ( - "SELECT count(*)" - " FROM event_push_actions ea" - " WHERE" - " highlight = 1" - " AND user_id = ?" - " AND room_id = ?" - " AND stream_ordering > ?" - ) - - txn.execute(sql, (user_id, room_id, stream_ordering)) row = txn.fetchone() - highlight_count = row[0] if row else 0 - return {"notify_count": notify_count, "highlight_count": highlight_count} + if row: + notif_count += row[0] + unread_count += row[1] + + return { + "notify_count": notif_count, + "unread_count": unread_count, + "highlight_count": highlight_count, + } async def get_push_action_users_in_range( self, min_stream_ordering, max_stream_ordering @@ -222,6 +244,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" + " AND ep.notif = 1" " ORDER BY ep.stream_ordering ASC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] @@ -250,6 +273,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" + " AND ep.notif = 1" " ORDER BY ep.stream_ordering ASC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] @@ -324,6 +348,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" + " AND ep.notif = 1" " ORDER BY ep.stream_ordering DESC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] @@ -352,6 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" + " AND ep.notif = 1" " ORDER BY ep.stream_ordering DESC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] @@ -383,62 +409,66 @@ class EventPushActionsWorkerStore(SQLBaseStore): # Now return the first `limit` return notifs[:limit] - def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering): + async def get_if_maybe_push_in_range_for_user( + self, user_id: str, min_stream_ordering: int + ) -> bool: """A fast check to see if there might be something to push for the user since the given stream ordering. May return false positives. Useful to know whether to bother starting a pusher on start up or not. Args: - user_id (str) - min_stream_ordering (int) + user_id + min_stream_ordering Returns: - Deferred[bool]: True if there may be push to process, False if - there definitely isn't. + True if there may be push to process, False if there definitely isn't. """ def _get_if_maybe_push_in_range_for_user_txn(txn): sql = """ SELECT 1 FROM event_push_actions - WHERE user_id = ? AND stream_ordering > ? + WHERE user_id = ? AND stream_ordering > ? AND notif = 1 LIMIT 1 """ txn.execute(sql, (user_id, min_stream_ordering)) return bool(txn.fetchone()) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_if_maybe_push_in_range_for_user", _get_if_maybe_push_in_range_for_user_txn, ) - async def add_push_actions_to_staging(self, event_id, user_id_actions): + async def add_push_actions_to_staging( + self, + event_id: str, + user_id_actions: Dict[str, List[Union[dict, str]]], + count_as_unread: bool, + ) -> None: """Add the push actions for the event to the push action staging area. Args: - event_id (str) - user_id_actions (dict[str, list[dict|str])]): A dictionary mapping - user_id to list of push actions, where an action can either be - a string or dict. - - Returns: - Deferred + event_id + user_id_actions: A mapping of user_id to list of push actions, where + an action can either be a string or dict. + count_as_unread: Whether this event should increment unread counts. """ - if not user_id_actions: return # This is a helper function for generating the necessary tuple that - # can be used to inert into the `event_push_actions_staging` table. + # can be used to insert into the `event_push_actions_staging` table. def _gen_entry(user_id, actions): is_highlight = 1 if _action_has_highlight(actions) else 0 + notif = 1 if "notify" in actions else 0 return ( event_id, # event_id column user_id, # user_id column _serialize_action(actions, is_highlight), # actions column - 1, # notif column + notif, # notif column is_highlight, # highlight column + int(count_as_unread), # unread column ) def _add_push_actions_to_staging_txn(txn): @@ -447,8 +477,8 @@ class EventPushActionsWorkerStore(SQLBaseStore): sql = """ INSERT INTO event_push_actions_staging - (event_id, user_id, actions, notif, highlight) - VALUES (?, ?, ?, ?, ?) + (event_id, user_id, actions, notif, highlight, unread) + VALUES (?, ?, ?, ?, ?, ?) """ txn.executemany( @@ -507,7 +537,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): "Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago ) - def find_first_stream_ordering_after_ts(self, ts): + async def find_first_stream_ordering_after_ts(self, ts: int) -> int: """Gets the stream ordering corresponding to a given timestamp. Specifically, finds the stream_ordering of the first event that was @@ -516,13 +546,12 @@ class EventPushActionsWorkerStore(SQLBaseStore): relatively slow. Args: - ts (int): timestamp in millis + ts: timestamp in millis Returns: - Deferred[int]: stream ordering of the first event received on/after - the timestamp + stream ordering of the first event received on/after the timestamp """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "_find_first_stream_ordering_after_ts_txn", self._find_first_stream_ordering_after_ts_txn, ts, @@ -813,24 +842,63 @@ class EventPushActionsStore(EventPushActionsWorkerStore): # Calculate the new counts that should be upserted into event_push_summary sql = """ SELECT user_id, room_id, - coalesce(old.notif_count, 0) + upd.notif_count, + coalesce(old.%s, 0) + upd.cnt, upd.stream_ordering, old.user_id FROM ( - SELECT user_id, room_id, count(*) as notif_count, + SELECT user_id, room_id, count(*) as cnt, max(stream_ordering) as stream_ordering FROM event_push_actions WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0 + AND %s = 1 GROUP BY user_id, room_id ) AS upd LEFT JOIN event_push_summary AS old USING (user_id, room_id) """ - txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering)) - rows = txn.fetchall() + # First get the count of unread messages. + txn.execute( + sql % ("unread_count", "unread"), + (old_rotate_stream_ordering, rotate_to_stream_ordering), + ) + + # We need to merge results from the two requests (the one that retrieves the + # unread count and the one that retrieves the notifications count) into a single + # object because we might not have the same amount of rows in each of them. To do + # this, we use a dict indexed on the user ID and room ID to make it easier to + # populate. + summaries = {} # type: Dict[Tuple[str, str], _EventPushSummary] + for row in txn: + summaries[(row[0], row[1])] = _EventPushSummary( + unread_count=row[2], + stream_ordering=row[3], + old_user_id=row[4], + notif_count=0, + ) + + # Then get the count of notifications. + txn.execute( + sql % ("notif_count", "notif"), + (old_rotate_stream_ordering, rotate_to_stream_ordering), + ) + + for row in txn: + if (row[0], row[1]) in summaries: + summaries[(row[0], row[1])].notif_count = row[2] + else: + # Because the rules on notifying are different than the rules on marking + # a message unread, we might end up with messages that notify but aren't + # marked unread, so we might not have a summary for this (user, room) + # tuple to complete. + summaries[(row[0], row[1])] = _EventPushSummary( + unread_count=0, + stream_ordering=row[3], + old_user_id=row[4], + notif_count=row[2], + ) - logger.info("Rotating notifications, handling %d rows", len(rows)) + logger.info("Rotating notifications, handling %d rows", len(summaries)) # If the `old.user_id` above is NULL then we know there isn't already an # entry in the table, so we simply insert it. Otherwise we update the @@ -840,22 +908,34 @@ class EventPushActionsStore(EventPushActionsWorkerStore): table="event_push_summary", values=[ { - "user_id": row[0], - "room_id": row[1], - "notif_count": row[2], - "stream_ordering": row[3], + "user_id": user_id, + "room_id": room_id, + "notif_count": summary.notif_count, + "unread_count": summary.unread_count, + "stream_ordering": summary.stream_ordering, } - for row in rows - if row[4] is None + for ((user_id, room_id), summary) in summaries.items() + if summary.old_user_id is None ], ) txn.executemany( """ - UPDATE event_push_summary SET notif_count = ?, stream_ordering = ? + UPDATE event_push_summary + SET notif_count = ?, unread_count = ?, stream_ordering = ? WHERE user_id = ? AND room_id = ? """, - ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None), + ( + ( + summary.notif_count, + summary.unread_count, + summary.stream_ordering, + user_id, + room_id, + ) + for ((user_id, room_id), summary) in summaries.items() + if summary.old_user_id is not None + ), ) txn.execute( @@ -881,3 +961,15 @@ def _action_has_highlight(actions): pass return False + + +@attr.s +class _EventPushSummary: + """Summary of pending event push actions for a given user in a given room. + Used in _rotate_notifs_before_txn to manipulate results from event_push_actions. + """ + + unread_count = attr.ib(type=int) + stream_ordering = attr.ib(type=int) + old_user_id = attr.ib(type=str) + notif_count = attr.ib(type=int)