summary refs log tree commit diff
path: root/synapse/storage/event_push_actions.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/event_push_actions.py')
-rw-r--r--synapse/storage/event_push_actions.py213
1 files changed, 123 insertions, 90 deletions
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 6840320641..a729f3e067 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -31,7 +31,9 @@ logger = logging.getLogger(__name__)
 
 DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}]
 DEFAULT_HIGHLIGHT_ACTION = [
-    "notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"}
+    "notify",
+    {"set_tweak": "sound", "value": "default"},
+    {"set_tweak": "highlight"},
 ]
 
 
@@ -91,25 +93,26 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
     @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
     def get_unread_event_push_actions_by_room_for_user(
-            self, room_id, user_id, last_read_event_id
+        self, room_id, user_id, last_read_event_id
     ):
         ret = yield self.runInteraction(
             "get_unread_event_push_actions_by_room",
             self._get_unread_counts_by_receipt_txn,
-            room_id, user_id, last_read_event_id
+            room_id,
+            user_id,
+            last_read_event_id,
         )
         defer.returnValue(ret)
 
-    def _get_unread_counts_by_receipt_txn(self, txn, room_id, user_id,
-                                          last_read_event_id):
+    def _get_unread_counts_by_receipt_txn(
+        self, txn, room_id, user_id, last_read_event_id
+    ):
         sql = (
             "SELECT stream_ordering"
             " FROM events"
             " WHERE room_id = ? AND event_id = ?"
         )
-        txn.execute(
-            sql, (room_id, last_read_event_id)
-        )
+        txn.execute(sql, (room_id, last_read_event_id))
         results = txn.fetchall()
         if len(results) == 0:
             return {"notify_count": 0, "highlight_count": 0}
@@ -138,10 +141,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         row = txn.fetchone()
         notify_count = row[0] if row else 0
 
-        txn.execute("""
+        txn.execute(
+            """
             SELECT notif_count FROM event_push_summary
             WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
-        """, (room_id, user_id, stream_ordering,))
+        """,
+            (room_id, user_id, stream_ordering),
+        )
         rows = txn.fetchall()
         if rows:
             notify_count += rows[0][0]
@@ -161,10 +167,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         row = txn.fetchone()
         highlight_count = row[0] if row else 0
 
-        return {
-            "notify_count": notify_count,
-            "highlight_count": highlight_count,
-        }
+        return {"notify_count": notify_count, "highlight_count": highlight_count}
 
     @defer.inlineCallbacks
     def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering):
@@ -175,6 +178,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             )
             txn.execute(sql, (min_stream_ordering, max_stream_ordering))
             return [r[0] for r in txn]
+
         ret = yield self.runInteraction("get_push_action_users_in_range", f)
         defer.returnValue(ret)
 
@@ -223,12 +227,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.stream_ordering <= ?"
                 " ORDER BY ep.stream_ordering ASC LIMIT ?"
             )
-            args = [
-                user_id, user_id,
-                min_stream_ordering, max_stream_ordering, limit,
-            ]
+            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
             return txn.fetchall()
+
         after_read_receipt = yield self.runInteraction(
             "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
         )
@@ -253,12 +255,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.stream_ordering <= ?"
                 " ORDER BY ep.stream_ordering ASC LIMIT ?"
             )
-            args = [
-                user_id, user_id,
-                min_stream_ordering, max_stream_ordering, limit,
-            ]
+            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
             return txn.fetchall()
+
         no_read_receipt = yield self.runInteraction(
             "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
         )
@@ -269,7 +269,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "room_id": row[1],
                 "stream_ordering": row[2],
                 "actions": _deserialize_action(row[3], row[4]),
-            } for row in after_read_receipt + no_read_receipt
+            }
+            for row in after_read_receipt + no_read_receipt
         ]
 
         # Now sort it so it's ordered correctly, since currently it will
@@ -326,12 +327,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.stream_ordering <= ?"
                 " ORDER BY ep.stream_ordering DESC LIMIT ?"
             )
-            args = [
-                user_id, user_id,
-                min_stream_ordering, max_stream_ordering, limit,
-            ]
+            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
             return txn.fetchall()
+
         after_read_receipt = yield self.runInteraction(
             "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
         )
@@ -356,12 +355,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.stream_ordering <= ?"
                 " ORDER BY ep.stream_ordering DESC LIMIT ?"
             )
-            args = [
-                user_id, user_id,
-                min_stream_ordering, max_stream_ordering, limit,
-            ]
+            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
             return txn.fetchall()
+
         no_read_receipt = yield self.runInteraction(
             "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
         )
@@ -374,7 +371,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "stream_ordering": row[2],
                 "actions": _deserialize_action(row[3], row[4]),
                 "received_ts": row[5],
-            } for row in after_read_receipt + no_read_receipt
+            }
+            for row in after_read_receipt + no_read_receipt
         ]
 
         # Now sort it so it's ordered correctly, since currently it will
@@ -386,6 +384,36 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # Now return the first `limit`
         defer.returnValue(notifs[:limit])
 
+    def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering):
+        """A fast check to see if there might be something to push for the
+        user since the given stream ordering. May return false positives.
+
+        Useful to know whether to bother starting a pusher on start up or not.
+
+        Args:
+            user_id (str)
+            min_stream_ordering (int)
+
+        Returns:
+            Deferred[bool]: True if there may be push to process, False if
+            there definitely isn't.
+        """
+
+        def _get_if_maybe_push_in_range_for_user_txn(txn):
+            sql = """
+                SELECT 1 FROM event_push_actions
+                WHERE user_id = ? AND stream_ordering > ?
+                LIMIT 1
+            """
+
+            txn.execute(sql, (user_id, min_stream_ordering))
+            return bool(txn.fetchone())
+
+        return self.runInteraction(
+            "get_if_maybe_push_in_range_for_user",
+            _get_if_maybe_push_in_range_for_user_txn,
+        )
+
     def add_push_actions_to_staging(self, event_id, user_id_actions):
         """Add the push actions for the event to the push action staging area.
 
@@ -424,10 +452,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 VALUES (?, ?, ?, ?, ?)
             """
 
-            txn.executemany(sql, (
-                _gen_entry(user_id, actions)
-                for user_id, actions in iteritems(user_id_actions)
-            ))
+            txn.executemany(
+                sql,
+                (
+                    _gen_entry(user_id, actions)
+                    for user_id, actions in iteritems(user_id_actions)
+                ),
+            )
 
         return self.runInteraction(
             "add_push_actions_to_staging", _add_push_actions_to_staging_txn
@@ -445,9 +476,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         try:
             res = yield self._simple_delete(
                 table="event_push_actions_staging",
-                keyvalues={
-                    "event_id": event_id,
-                },
+                keyvalues={"event_id": event_id},
                 desc="remove_push_actions_from_staging",
             )
             defer.returnValue(res)
@@ -456,7 +485,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             # another exception here really isn't helpful - there's nothing
             # the caller can do about it. Just log the exception and move on.
             logger.exception(
-                "Error removing push actions after event persistence failure",
+                "Error removing push actions after event persistence failure"
             )
 
     def _find_stream_orderings_for_times(self):
@@ -473,16 +502,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
         )
         logger.info(
-            "Found stream ordering 1 month ago: it's %d",
-            self.stream_ordering_month_ago
+            "Found stream ordering 1 month ago: it's %d", self.stream_ordering_month_ago
         )
         logger.info("Searching for stream ordering 1 day ago")
         self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
             txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
         )
         logger.info(
-            "Found stream ordering 1 day ago: it's %d",
-            self.stream_ordering_day_ago
+            "Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago
         )
 
     def find_first_stream_ordering_after_ts(self, ts):
@@ -601,16 +628,17 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             index_name="event_push_actions_highlights_index",
             table="event_push_actions",
             columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
-            where_clause="highlight=1"
+            where_clause="highlight=1",
         )
 
         self._doing_notif_rotation = False
         self._rotate_notif_loop = self._clock.looping_call(
-            self._start_rotate_notifs, 30 * 60 * 1000,
+            self._start_rotate_notifs, 30 * 60 * 1000
         )
 
-    def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts,
-                                                  all_events_and_contexts):
+    def _set_push_actions_for_event_and_users_txn(
+        self, txn, events_and_contexts, all_events_and_contexts
+    ):
         """Handles moving push actions from staging table to main
         event_push_actions table for all events in `events_and_contexts`.
 
@@ -637,43 +665,44 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         """
 
         if events_and_contexts:
-            txn.executemany(sql, (
+            txn.executemany(
+                sql,
                 (
-                    event.room_id, event.internal_metadata.stream_ordering,
-                    event.depth, event.event_id,
-                )
-                for event, _ in events_and_contexts
-            ))
+                    (
+                        event.room_id,
+                        event.internal_metadata.stream_ordering,
+                        event.depth,
+                        event.event_id,
+                    )
+                    for event, _ in events_and_contexts
+                ),
+            )
 
         for event, _ in events_and_contexts:
             user_ids = self._simple_select_onecol_txn(
                 txn,
                 table="event_push_actions_staging",
-                keyvalues={
-                    "event_id": event.event_id,
-                },
+                keyvalues={"event_id": event.event_id},
                 retcol="user_id",
             )
 
             for uid in user_ids:
                 txn.call_after(
                     self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-                    (event.room_id, uid,)
+                    (event.room_id, uid),
                 )
 
         # Now we delete the staging area for *all* events that were being
         # persisted.
         txn.executemany(
             "DELETE FROM event_push_actions_staging WHERE event_id = ?",
-            (
-                (event.event_id,)
-                for event, _ in all_events_and_contexts
-            )
+            ((event.event_id,) for event, _ in all_events_and_contexts),
         )
 
     @defer.inlineCallbacks
-    def get_push_actions_for_user(self, user_id, before=None, limit=50,
-                                  only_highlight=False):
+    def get_push_actions_for_user(
+        self, user_id, before=None, limit=50, only_highlight=False
+    ):
         def f(txn):
             before_clause = ""
             if before:
@@ -697,15 +726,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
                 " WHERE epa.event_id = e.event_id"
                 " AND epa.user_id = ? %s"
                 " ORDER BY epa.stream_ordering DESC"
-                " LIMIT ?"
-                % (before_clause,)
+                " LIMIT ?" % (before_clause,)
             )
             txn.execute(sql, args)
             return self.cursor_to_dict(txn)
 
-        push_actions = yield self.runInteraction(
-            "get_push_actions_for_user", f
-        )
+        push_actions = yield self.runInteraction("get_push_actions_for_user", f)
         for pa in push_actions:
             pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
         defer.returnValue(push_actions)
@@ -723,6 +749,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             )
             txn.execute(sql, (stream_ordering,))
             return txn.fetchone()
+
         result = yield self.runInteraction("get_time_of_last_push_action_before", f)
         defer.returnValue(result[0] if result else None)
 
@@ -731,24 +758,24 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         def f(txn):
             txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
             return txn.fetchone()
-        result = yield self.runInteraction(
-            "get_latest_push_action_stream_ordering", f
-        )
+
+        result = yield self.runInteraction("get_latest_push_action_stream_ordering", f)
         defer.returnValue(result[0] or 0)
 
     def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
         # Sad that we have to blow away the cache for the whole room here
         txn.call_after(
             self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-            (room_id,)
+            (room_id,),
         )
         txn.execute(
             "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
-            (room_id, event_id)
+            (room_id, event_id),
         )
 
-    def _remove_old_push_actions_before_txn(self, txn, room_id, user_id,
-                                            stream_ordering):
+    def _remove_old_push_actions_before_txn(
+        self, txn, room_id, user_id, stream_ordering
+    ):
         """
         Purges old push actions for a user and room before a given
         stream_ordering.
@@ -765,7 +792,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         """
         txn.call_after(
             self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-            (room_id, user_id, )
+            (room_id, user_id),
         )
 
         # We need to join on the events table to get the received_ts for
@@ -781,13 +808,16 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             " WHERE user_id = ? AND room_id = ? AND "
             " stream_ordering <= ?"
             " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
-            (user_id, room_id, stream_ordering, self.stream_ordering_month_ago)
+            (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
         )
 
-        txn.execute("""
+        txn.execute(
+            """
             DELETE FROM event_push_summary
             WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
-        """, (room_id, user_id, stream_ordering))
+        """,
+            (room_id, user_id, stream_ordering),
+        )
 
     def _start_rotate_notifs(self):
         return run_as_background_process("rotate_notifs", self._rotate_notifs)
@@ -803,8 +833,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
                 logger.info("Rotating notifications")
 
                 caught_up = yield self.runInteraction(
-                    "_rotate_notifs",
-                    self._rotate_notifs_txn
+                    "_rotate_notifs", self._rotate_notifs_txn
                 )
                 if caught_up:
                     break
@@ -826,11 +855,14 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
 
         # We don't to try and rotate millions of rows at once, so we cap the
         # maximum stream ordering we'll rotate before.
-        txn.execute("""
+        txn.execute(
+            """
             SELECT stream_ordering FROM event_push_actions
             WHERE stream_ordering > ?
             ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
-        """, (old_rotate_stream_ordering, self._rotate_count))
+        """,
+            (old_rotate_stream_ordering, self._rotate_count),
+        )
         stream_row = txn.fetchone()
         if stream_row:
             offset_stream_ordering, = stream_row
@@ -874,7 +906,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             LEFT JOIN event_push_summary AS old USING (user_id, room_id)
         """
 
-        txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering,))
+        txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering))
         rows = txn.fetchall()
 
         logger.info("Rotating notifications, handling %d rows", len(rows))
@@ -892,8 +924,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
                     "notif_count": row[2],
                     "stream_ordering": row[3],
                 }
-                for row in rows if row[4] is None
-            ]
+                for row in rows
+                if row[4] is None
+            ],
         )
 
         txn.executemany(
@@ -901,20 +934,20 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
                 UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
                 WHERE user_id = ? AND room_id = ?
             """,
-            ((row[2], row[3], row[0], row[1],) for row in rows if row[4] is not None)
+            ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None),
         )
 
         txn.execute(
             "DELETE FROM event_push_actions"
             " WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0",
-            (old_rotate_stream_ordering, rotate_to_stream_ordering,)
+            (old_rotate_stream_ordering, rotate_to_stream_ordering),
         )
 
         logger.info("Rotating notifications, deleted %s push actions", txn.rowcount)
 
         txn.execute(
             "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?",
-            (rotate_to_stream_ordering,)
+            (rotate_to_stream_ordering,),
         )