summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-10-09 07:37:51 -0400
committerGitHub <noreply@github.com>2020-10-09 07:37:51 -0400
commitfe0f4a3591302176c7eea48a54f6ed83d9eb4aa9 (patch)
tree831e5624bda8c4540ed5df42f1c78b0304d4d5f0 /synapse/storage/databases/main
parentRemove the deprecated Handlers object (#8494) (diff)
downloadsynapse-fe0f4a3591302176c7eea48a54f6ed83d9eb4aa9.tar.xz
Move additional tasks to the background worker, part 3 (#8489)
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/censor_events.py15
-rw-r--r--synapse/storage/databases/main/devices.py196
-rw-r--r--synapse/storage/databases/main/event_federation.py60
-rw-r--r--synapse/storage/databases/main/event_push_actions.py259
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py2
-rw-r--r--synapse/storage/databases/main/registration.py11
6 files changed, 272 insertions, 271 deletions
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 4bb2b9c28c..849bd5ba7a 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -17,7 +17,7 @@ import logging
 from typing import TYPE_CHECKING
 
 from synapse.events.utils import prune_event_dict
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
@@ -35,14 +35,13 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
     def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
-        def _censor_redactions():
-            return run_as_background_process(
-                "_censor_redactions", self._censor_redactions
-            )
-
-        if self.hs.config.redaction_retention_period is not None:
-            hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
+        if (
+            hs.config.run_background_tasks
+            and self.hs.config.redaction_retention_period is not None
+        ):
+            hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)
 
+    @wrap_as_background_process("_censor_redactions")
     async def _censor_redactions(self):
         """Censors all redactions older than the configured period that haven't
         been censored yet.
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 2d0a6408b5..88fd97e1df 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -25,7 +25,7 @@ from synapse.logging.opentracing import (
     trace,
     whitelisted_homeserver,
 )
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
@@ -48,6 +48,14 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
 
 class DeviceWorkerStore(SQLBaseStore):
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        if hs.config.run_background_tasks:
+            self._clock.looping_call(
+                self._prune_old_outbound_device_pokes, 60 * 60 * 1000
+            )
+
     async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
         """Retrieve a device. Only returns devices that are not marked as
         hidden.
@@ -772,6 +780,98 @@ class DeviceWorkerStore(SQLBaseStore):
         )
         return count >= 1
 
+    @wrap_as_background_process("prune_old_outbound_device_pokes")
+    async def _prune_old_outbound_device_pokes(
+        self, prune_age: int = 24 * 60 * 60 * 1000
+    ) -> None:
+        """Delete old entries out of the device_lists_outbound_pokes to ensure
+        that we don't fill up due to dead servers.
+
+        Normally, we try to send device updates as a delta since a previous known point:
+        this is done by setting the prev_id in the m.device_list_update EDU. However,
+        for that to work, we have to have a complete record of each change to
+        each device, which can add up to quite a lot of data.
+
+        An alternative mechanism is that, if the remote server sees that it has missed
+        an entry in the stream_id sequence for a given user, it will request a full
+        list of that user's devices. Hence, we can reduce the amount of data we have to
+        store (and transmit in some future transaction), by clearing almost everything
+        for a given destination out of the database, and having the remote server
+        resync.
+
+        All we need to do is make sure we keep at least one row for each
+        (user, destination) pair, to remind us to send a m.device_list_update EDU for
+        that user when the destination comes back. It doesn't matter which device
+        we keep.
+        """
+        yesterday = self._clock.time_msec() - prune_age
+
+        def _prune_txn(txn):
+            # look for (user, destination) pairs which have an update older than
+            # the cutoff.
+            #
+            # For each pair, we also need to know the most recent stream_id, and
+            # an arbitrary device_id at that stream_id.
+            select_sql = """
+            SELECT
+                dlop1.destination,
+                dlop1.user_id,
+                MAX(dlop1.stream_id) AS stream_id,
+                (SELECT MIN(dlop2.device_id) AS device_id FROM
+                    device_lists_outbound_pokes dlop2
+                    WHERE dlop2.destination = dlop1.destination AND
+                      dlop2.user_id=dlop1.user_id AND
+                      dlop2.stream_id=MAX(dlop1.stream_id)
+                )
+            FROM device_lists_outbound_pokes dlop1
+                GROUP BY destination, user_id
+                HAVING min(ts) < ? AND count(*) > 1
+            """
+
+            txn.execute(select_sql, (yesterday,))
+            rows = txn.fetchall()
+
+            if not rows:
+                return
+
+            logger.info(
+                "Pruning old outbound device list updates for %i users/destinations: %s",
+                len(rows),
+                shortstr((row[0], row[1]) for row in rows),
+            )
+
+            # we want to keep the update with the highest stream_id for each user.
+            #
+            # there might be more than one update (with different device_ids) with the
+            # same stream_id, so we also delete all but one rows with the max stream id.
+            delete_sql = """
+                DELETE FROM device_lists_outbound_pokes
+                WHERE destination = ? AND user_id = ? AND (
+                    stream_id < ? OR
+                    (stream_id = ? AND device_id != ?)
+                )
+            """
+            count = 0
+            for (destination, user_id, stream_id, device_id) in rows:
+                txn.execute(
+                    delete_sql, (destination, user_id, stream_id, stream_id, device_id)
+                )
+                count += txn.rowcount
+
+            # Since we've deleted unsent deltas, we need to remove the entry
+            # of last successful sent so that the prev_ids are correctly set.
+            sql = """
+                DELETE FROM device_lists_outbound_last_success
+                WHERE destination = ? AND user_id = ?
+            """
+            txn.executemany(sql, ((row[0], row[1]) for row in rows))
+
+            logger.info("Pruned %d device list outbound pokes", count)
+
+        await self.db_pool.runInteraction(
+            "_prune_old_outbound_device_pokes", _prune_txn,
+        )
+
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
@@ -908,8 +1008,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             name="device_id_exists", keylen=2, max_entries=10000
         )
 
-        self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
-
     async def store_device(
         self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
     ) -> bool:
@@ -1267,95 +1365,3 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 for device_id in device_ids
             ],
         )
-
-    def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000):
-        """Delete old entries out of the device_lists_outbound_pokes to ensure
-        that we don't fill up due to dead servers.
-
-        Normally, we try to send device updates as a delta since a previous known point:
-        this is done by setting the prev_id in the m.device_list_update EDU. However,
-        for that to work, we have to have a complete record of each change to
-        each device, which can add up to quite a lot of data.
-
-        An alternative mechanism is that, if the remote server sees that it has missed
-        an entry in the stream_id sequence for a given user, it will request a full
-        list of that user's devices. Hence, we can reduce the amount of data we have to
-        store (and transmit in some future transaction), by clearing almost everything
-        for a given destination out of the database, and having the remote server
-        resync.
-
-        All we need to do is make sure we keep at least one row for each
-        (user, destination) pair, to remind us to send a m.device_list_update EDU for
-        that user when the destination comes back. It doesn't matter which device
-        we keep.
-        """
-        yesterday = self._clock.time_msec() - prune_age
-
-        def _prune_txn(txn):
-            # look for (user, destination) pairs which have an update older than
-            # the cutoff.
-            #
-            # For each pair, we also need to know the most recent stream_id, and
-            # an arbitrary device_id at that stream_id.
-            select_sql = """
-            SELECT
-                dlop1.destination,
-                dlop1.user_id,
-                MAX(dlop1.stream_id) AS stream_id,
-                (SELECT MIN(dlop2.device_id) AS device_id FROM
-                    device_lists_outbound_pokes dlop2
-                    WHERE dlop2.destination = dlop1.destination AND
-                      dlop2.user_id=dlop1.user_id AND
-                      dlop2.stream_id=MAX(dlop1.stream_id)
-                )
-            FROM device_lists_outbound_pokes dlop1
-                GROUP BY destination, user_id
-                HAVING min(ts) < ? AND count(*) > 1
-            """
-
-            txn.execute(select_sql, (yesterday,))
-            rows = txn.fetchall()
-
-            if not rows:
-                return
-
-            logger.info(
-                "Pruning old outbound device list updates for %i users/destinations: %s",
-                len(rows),
-                shortstr((row[0], row[1]) for row in rows),
-            )
-
-            # we want to keep the update with the highest stream_id for each user.
-            #
-            # there might be more than one update (with different device_ids) with the
-            # same stream_id, so we also delete all but one rows with the max stream id.
-            delete_sql = """
-                DELETE FROM device_lists_outbound_pokes
-                WHERE destination = ? AND user_id = ? AND (
-                    stream_id < ? OR
-                    (stream_id = ? AND device_id != ?)
-                )
-            """
-            count = 0
-            for (destination, user_id, stream_id, device_id) in rows:
-                txn.execute(
-                    delete_sql, (destination, user_id, stream_id, stream_id, device_id)
-                )
-                count += txn.rowcount
-
-            # Since we've deleted unsent deltas, we need to remove the entry
-            # of last successful sent so that the prev_ids are correctly set.
-            sql = """
-                DELETE FROM device_lists_outbound_last_success
-                WHERE destination = ? AND user_id = ?
-            """
-            txn.executemany(sql, ((row[0], row[1]) for row in rows))
-
-            logger.info("Pruned %d device list outbound pokes", count)
-
-        return run_as_background_process(
-            "prune_old_outbound_device_pokes",
-            self.db_pool.runInteraction,
-            "_prune_old_outbound_device_pokes",
-            _prune_txn,
-        )
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 6d3689c09e..a6279a6c13 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -19,7 +19,7 @@ from typing import Dict, Iterable, List, Set, Tuple
 
 from synapse.api.errors import StoreError
 from synapse.events import EventBase
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
@@ -32,6 +32,14 @@ logger = logging.getLogger(__name__)
 
 
 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        if hs.config.run_background_tasks:
+            hs.get_clock().looping_call(
+                self._delete_old_forward_extrem_cache, 60 * 60 * 1000
+            )
+
     async def get_auth_chain(
         self, event_ids: Collection[str], include_given: bool = False
     ) -> List[EventBase]:
@@ -586,6 +594,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return [row["event_id"] for row in rows]
 
+    @wrap_as_background_process("delete_old_forward_extrem_cache")
+    async def _delete_old_forward_extrem_cache(self) -> None:
+        def _delete_old_forward_extrem_cache_txn(txn):
+            # Delete entries older than a month, while making sure we don't delete
+            # the only entries for a room.
+            sql = """
+                DELETE FROM stream_ordering_to_exterm
+                WHERE
+                room_id IN (
+                    SELECT room_id
+                    FROM stream_ordering_to_exterm
+                    WHERE stream_ordering > ?
+                ) AND stream_ordering < ?
+            """
+            txn.execute(
+                sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
+            )
+
+        await self.db_pool.runInteraction(
+            "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn,
+        )
+
 
 class EventFederationStore(EventFederationWorkerStore):
     """ Responsible for storing and serving up the various graphs associated
@@ -606,34 +636,6 @@ class EventFederationStore(EventFederationWorkerStore):
             self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
         )
 
-        hs.get_clock().looping_call(
-            self._delete_old_forward_extrem_cache, 60 * 60 * 1000
-        )
-
-    def _delete_old_forward_extrem_cache(self):
-        def _delete_old_forward_extrem_cache_txn(txn):
-            # Delete entries older than a month, while making sure we don't delete
-            # the only entries for a room.
-            sql = """
-                DELETE FROM stream_ordering_to_exterm
-                WHERE
-                room_id IN (
-                    SELECT room_id
-                    FROM stream_ordering_to_exterm
-                    WHERE stream_ordering > ?
-                ) AND stream_ordering < ?
-            """
-            txn.execute(
-                sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
-            )
-
-        return run_as_background_process(
-            "delete_old_forward_extrem_cache",
-            self.db_pool.runInteraction,
-            "_delete_old_forward_extrem_cache",
-            _delete_old_forward_extrem_cache_txn,
-        )
-
     async def clean_room_for_join(self, room_id):
         return await self.db_pool.runInteraction(
             "clean_room_for_join", self._clean_room_for_join_txn, room_id
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 80f3b4d740..2e56dfaf31 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -13,15 +13,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from typing import Dict, List, Optional, Tuple, Union
 
 import attr
 
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
@@ -81,8 +80,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         self.find_stream_orderings_looping_call = self._clock.looping_call(
             self._find_stream_orderings_for_times, 10 * 60 * 1000
         )
+
         self._rotate_delay = 3
         self._rotate_count = 10000
+        self._doing_notif_rotation = False
+        if hs.config.run_background_tasks:
+            self._rotate_notif_loop = self._clock.looping_call(
+                self._rotate_notifs, 30 * 60 * 1000
+            )
 
     @cached(num_args=3, tree=True, max_entries=5000)
     async def get_unread_event_push_actions_by_room_for_user(
@@ -514,15 +519,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "Error removing push actions after event persistence failure"
             )
 
-    def _find_stream_orderings_for_times(self):
-        return run_as_background_process(
-            "event_push_action_stream_orderings",
-            self.db_pool.runInteraction,
+    @wrap_as_background_process("event_push_action_stream_orderings")
+    async def _find_stream_orderings_for_times(self) -> None:
+        await self.db_pool.runInteraction(
             "_find_stream_orderings_for_times",
             self._find_stream_orderings_for_times_txn,
         )
 
-    def _find_stream_orderings_for_times_txn(self, txn):
+    def _find_stream_orderings_for_times_txn(self, txn: LoggingTransaction) -> None:
         logger.info("Searching for stream ordering 1 month ago")
         self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
             txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
@@ -652,129 +656,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         )
         return result[0] if result else None
 
-
-class EventPushActionsStore(EventPushActionsWorkerStore):
-    EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
-
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-
-        self.db_pool.updates.register_background_index_update(
-            self.EPA_HIGHLIGHT_INDEX,
-            index_name="event_push_actions_u_highlight",
-            table="event_push_actions",
-            columns=["user_id", "stream_ordering"],
-        )
-
-        self.db_pool.updates.register_background_index_update(
-            "event_push_actions_highlights_index",
-            index_name="event_push_actions_highlights_index",
-            table="event_push_actions",
-            columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
-            where_clause="highlight=1",
-        )
-
-        self._doing_notif_rotation = False
-        self._rotate_notif_loop = self._clock.looping_call(
-            self._start_rotate_notifs, 30 * 60 * 1000
-        )
-
-    async def get_push_actions_for_user(
-        self, user_id, before=None, limit=50, only_highlight=False
-    ):
-        def f(txn):
-            before_clause = ""
-            if before:
-                before_clause = "AND epa.stream_ordering < ?"
-                args = [user_id, before, limit]
-            else:
-                args = [user_id, limit]
-
-            if only_highlight:
-                if len(before_clause) > 0:
-                    before_clause += " "
-                before_clause += "AND epa.highlight = 1"
-
-            # NB. This assumes event_ids are globally unique since
-            # it makes the query easier to index
-            sql = (
-                "SELECT epa.event_id, epa.room_id,"
-                " epa.stream_ordering, epa.topological_ordering,"
-                " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
-                " FROM event_push_actions epa, events e"
-                " WHERE epa.event_id = e.event_id"
-                " AND epa.user_id = ? %s"
-                " AND epa.notif = 1"
-                " ORDER BY epa.stream_ordering DESC"
-                " LIMIT ?" % (before_clause,)
-            )
-            txn.execute(sql, args)
-            return self.db_pool.cursor_to_dict(txn)
-
-        push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
-        for pa in push_actions:
-            pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
-        return push_actions
-
-    async def get_latest_push_action_stream_ordering(self):
-        def f(txn):
-            txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
-            return txn.fetchone()
-
-        result = await self.db_pool.runInteraction(
-            "get_latest_push_action_stream_ordering", f
-        )
-        return result[0] or 0
-
-    def _remove_old_push_actions_before_txn(
-        self, txn, room_id, user_id, stream_ordering
-    ):
-        """
-        Purges old push actions for a user and room before a given
-        stream_ordering.
-
-        We however keep a months worth of highlighted notifications, so that
-        users can still get a list of recent highlights.
-
-        Args:
-            txn: The transcation
-            room_id: Room ID to delete from
-            user_id: user ID to delete for
-            stream_ordering: The lowest stream ordering which will
-                                  not be deleted.
-        """
-        txn.call_after(
-            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-            (room_id, user_id),
-        )
-
-        # We need to join on the events table to get the received_ts for
-        # event_push_actions and sqlite won't let us use a join in a delete so
-        # we can't just delete where received_ts < x. Furthermore we can
-        # only identify event_push_actions by a tuple of room_id, event_id
-        # we we can't use a subquery.
-        # Instead, we look up the stream ordering for the last event in that
-        # room received before the threshold time and delete event_push_actions
-        # in the room with a stream_odering before that.
-        txn.execute(
-            "DELETE FROM event_push_actions "
-            " WHERE user_id = ? AND room_id = ? AND "
-            " stream_ordering <= ?"
-            " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
-            (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
-        )
-
-        txn.execute(
-            """
-            DELETE FROM event_push_summary
-            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
-        """,
-            (room_id, user_id, stream_ordering),
-        )
-
-    def _start_rotate_notifs(self):
-        return run_as_background_process("rotate_notifs", self._rotate_notifs)
-
+    @wrap_as_background_process("rotate_notifs")
     async def _rotate_notifs(self):
         if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
             return
@@ -954,6 +836,121 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         )
 
 
+class EventPushActionsStore(EventPushActionsWorkerStore):
+    EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
+
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self.db_pool.updates.register_background_index_update(
+            self.EPA_HIGHLIGHT_INDEX,
+            index_name="event_push_actions_u_highlight",
+            table="event_push_actions",
+            columns=["user_id", "stream_ordering"],
+        )
+
+        self.db_pool.updates.register_background_index_update(
+            "event_push_actions_highlights_index",
+            index_name="event_push_actions_highlights_index",
+            table="event_push_actions",
+            columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
+            where_clause="highlight=1",
+        )
+
+    async def get_push_actions_for_user(
+        self, user_id, before=None, limit=50, only_highlight=False
+    ):
+        def f(txn):
+            before_clause = ""
+            if before:
+                before_clause = "AND epa.stream_ordering < ?"
+                args = [user_id, before, limit]
+            else:
+                args = [user_id, limit]
+
+            if only_highlight:
+                if len(before_clause) > 0:
+                    before_clause += " "
+                before_clause += "AND epa.highlight = 1"
+
+            # NB. This assumes event_ids are globally unique since
+            # it makes the query easier to index
+            sql = (
+                "SELECT epa.event_id, epa.room_id,"
+                " epa.stream_ordering, epa.topological_ordering,"
+                " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
+                " FROM event_push_actions epa, events e"
+                " WHERE epa.event_id = e.event_id"
+                " AND epa.user_id = ? %s"
+                " AND epa.notif = 1"
+                " ORDER BY epa.stream_ordering DESC"
+                " LIMIT ?" % (before_clause,)
+            )
+            txn.execute(sql, args)
+            return self.db_pool.cursor_to_dict(txn)
+
+        push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
+        for pa in push_actions:
+            pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
+        return push_actions
+
+    async def get_latest_push_action_stream_ordering(self):
+        def f(txn):
+            txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
+            return txn.fetchone()
+
+        result = await self.db_pool.runInteraction(
+            "get_latest_push_action_stream_ordering", f
+        )
+        return result[0] or 0
+
+    def _remove_old_push_actions_before_txn(
+        self, txn, room_id, user_id, stream_ordering
+    ):
+        """
+        Purges old push actions for a user and room before a given
+        stream_ordering.
+
+        We however keep a months worth of highlighted notifications, so that
+        users can still get a list of recent highlights.
+
+        Args:
+            txn: The transcation
+            room_id: Room ID to delete from
+            user_id: user ID to delete for
+            stream_ordering: The lowest stream ordering which will
+                                  not be deleted.
+        """
+        txn.call_after(
+            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+            (room_id, user_id),
+        )
+
+        # We need to join on the events table to get the received_ts for
+        # event_push_actions and sqlite won't let us use a join in a delete so
+        # we can't just delete where received_ts < x. Furthermore we can
+        # only identify event_push_actions by a tuple of room_id, event_id
+        # we we can't use a subquery.
+        # Instead, we look up the stream ordering for the last event in that
+        # room received before the threshold time and delete event_push_actions
+        # in the room with a stream_odering before that.
+        txn.execute(
+            "DELETE FROM event_push_actions "
+            " WHERE user_id = ? AND room_id = ? AND "
+            " stream_ordering <= ?"
+            " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
+            (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
+        )
+
+        txn.execute(
+            """
+            DELETE FROM event_push_summary
+            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
+        """,
+            (room_id, user_id, stream_ordering),
+        )
+
+
 def _action_has_highlight(actions):
     for action in actions:
         try:
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index c66f558567..d788dc0fc6 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -15,6 +15,7 @@
 import logging
 from typing import Dict, List
 
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool, make_in_list_sql_clause
 from synapse.util.caches.descriptors import cached
@@ -127,6 +128,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
             desc="user_last_seen_monthly_active",
         )
 
+    @wrap_as_background_process("reap_monthly_active_users")
     async def reap_monthly_active_users(self):
         """Cleans out monthly active user table to ensure that no stale
         entries exist.
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 7fd7b0b952..236d3cdbe3 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -20,10 +20,7 @@ from typing import Any, Dict, List, Optional, Tuple
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
-from synapse.metrics.background_process_metrics import (
-    run_as_background_process,
-    wrap_as_background_process,
-)
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 from synapse.storage.types import Cursor
@@ -53,10 +50,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         self._account_validity = hs.config.account_validity
         if hs.config.run_background_tasks and self._account_validity.enabled:
             self._clock.call_later(
-                0.0,
-                run_as_background_process,
-                "account_validity_set_expiration_dates",
-                self._set_expiration_date_when_missing,
+                0.0, self._set_expiration_date_when_missing,
             )
 
         # Create a background job for culling expired 3PID validity tokens
@@ -812,6 +806,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             self.clock.time_msec(),
         )
 
+    @wrap_as_background_process("account_validity_set_expiration_dates")
     async def _set_expiration_date_when_missing(self):
         """
         Retrieves the list of registered users that don't have an expiration date, and