Diffstat (limited to 'synapse/storage/databases/main/devices.py')
-rw-r--r--  synapse/storage/databases/main/devices.py | 196
1 file changed, 101 insertions(+), 95 deletions(-)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 2d0a6408b5..88fd97e1df 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -25,7 +25,7 @@ from synapse.logging.opentracing import (
     trace,
     whitelisted_homeserver,
 )
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
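
The import swap above replaces the call-style run_as_background_process helper
with a decorator. As a rough illustration of the semantics this diff implies
(a sketch, not Synapse's actual implementation, which lives in
synapse.metrics.background_process_metrics), such a decorator can be built
directly on top of the older helper:

    from functools import wraps

    from synapse.metrics.background_process_metrics import run_as_background_process

    def wrap_as_background_process(desc):
        """Sketch: make every call to the decorated async function run as a
        tracked background process named `desc`."""
        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                return run_as_background_process(desc, func, *args, **kwargs)
            return wrapper
        return decorator
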
@@ -48,6 +48,14 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
 
 class DeviceWorkerStore(SQLBaseStore):
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        if hs.config.run_background_tasks:
+            self._clock.looping_call(
+                self._prune_old_outbound_device_pokes, 60 * 60 * 1000
+            )
+
     async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
         """Retrieve a device. Only returns devices that are not marked as
         hidden.
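
The new __init__ above schedules the prune to run hourly, but only on the
instance configured to run background tasks (hs.config.run_background_tasks).
Synapse's Clock.looping_call is, as far as I know, a thin wrapper around
Twisted's LoopingCall; a standalone sketch of the same scheduling pattern
(note that looping_call takes milliseconds, while LoopingCall.start takes
seconds):

    from twisted.internet import task

    def schedule_hourly(fn):
        # Fire fn every hour; now=False means don't run immediately on startup.
        loop = task.LoopingCall(fn)
        loop.start(60 * 60, now=False)
        return loop
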
@@ -772,6 +780,98 @@ class DeviceWorkerStore(SQLBaseStore):
         )
         return count >= 1
 
+    @wrap_as_background_process("prune_old_outbound_device_pokes")
+    async def _prune_old_outbound_device_pokes(
+        self, prune_age: int = 24 * 60 * 60 * 1000
+    ) -> None:
+        """Delete old entries out of the device_lists_outbound_pokes to ensure
+        that we don't fill up due to dead servers.
+
+        Normally, we try to send device updates as a delta since a previous known point:
+        this is done by setting the prev_id in the m.device_list_update EDU. However,
+        for that to work, we have to have a complete record of each change to
+        each device, which can add up to quite a lot of data.
+
+        An alternative mechanism is that, if the remote server sees that it has missed
+        an entry in the stream_id sequence for a given user, it will request a full
+        list of that user's devices. Hence, we can reduce the amount of data we have to
+        store (and transmit in some future transaction), by clearing almost everything
+        for a given destination out of the database, and having the remote server
+        resync.
+
+        All we need to do is make sure we keep at least one row for each
+        (user, destination) pair, to remind us to send an m.device_list_update EDU for
+        that user when the destination comes back. It doesn't matter which device
+        we keep.
+        """
+        yesterday = self._clock.time_msec() - prune_age
+
+        def _prune_txn(txn):
+            # look for (user, destination) pairs which have an update older than
+            # the cutoff.
+            #
+            # For each pair, we also need to know the most recent stream_id, and
+            # an arbitrary device_id at that stream_id.
+            select_sql = """
+            SELECT
+                dlop1.destination,
+                dlop1.user_id,
+                MAX(dlop1.stream_id) AS stream_id,
+                (SELECT MIN(dlop2.device_id) AS device_id FROM
+                    device_lists_outbound_pokes dlop2
+                    WHERE dlop2.destination = dlop1.destination AND
+                      dlop2.user_id=dlop1.user_id AND
+                      dlop2.stream_id=MAX(dlop1.stream_id)
+                )
+            FROM device_lists_outbound_pokes dlop1
+                GROUP BY destination, user_id
+                HAVING min(ts) < ? AND count(*) > 1
+            """
+
+            txn.execute(select_sql, (yesterday,))
+            rows = txn.fetchall()
+
+            if not rows:
+                return
+
+            logger.info(
+                "Pruning old outbound device list updates for %i users/destinations: %s",
+                len(rows),
+                shortstr((row[0], row[1]) for row in rows),
+            )
+
+            # we want to keep the update with the highest stream_id for each user.
+            #
+            # there might be more than one update (with different device_ids) at the
+            # max stream_id, so we also delete all but one of those rows.
+            delete_sql = """
+                DELETE FROM device_lists_outbound_pokes
+                WHERE destination = ? AND user_id = ? AND (
+                    stream_id < ? OR
+                    (stream_id = ? AND device_id != ?)
+                )
+            """
+            count = 0
+            for (destination, user_id, stream_id, device_id) in rows:
+                txn.execute(
+                    delete_sql, (destination, user_id, stream_id, stream_id, device_id)
+                )
+                count += txn.rowcount
+
+            # Since we've deleted unsent deltas, we also need to remove the record
+            # of the last successful send, so that the prev_ids are set correctly.
+            sql = """
+                DELETE FROM device_lists_outbound_last_success
+                WHERE destination = ? AND user_id = ?
+            """
+            txn.executemany(sql, ((row[0], row[1]) for row in rows))
+
+            logger.info("Pruned %d device list outbound pokes", count)
+
+        await self.db_pool.runInteraction(
+            "_prune_old_outbound_device_pokes", _prune_txn,
+        )
+
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
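
To make the keep/delete rule in the SQL above concrete: for each
(destination, user_id) pair that has more than one row and whose oldest row
predates the cutoff, everything is deleted except the single row at the
highest stream_id (ties broken by the lowest device_id). A hypothetical
in-memory model of that rule:

    def rows_to_delete(rows):
        """rows: list of (destination, user_id, stream_id, device_id) tuples
        for one pruning sweep; returns the rows the DELETE would drop."""
        keep = {}
        for dest, user, stream_id, device_id in rows:
            best = keep.get((dest, user))
            if (
                best is None
                or stream_id > best[0]
                or (stream_id == best[0] and device_id < best[1])
            ):
                keep[(dest, user)] = (stream_id, device_id)
        return [r for r in rows if (r[2], r[3]) != keep[(r[0], r[1])]]

    # The stream-1 row and the duplicate at stream 2 are dropped;
    # only (2, "DEV2") survives for this (destination, user) pair:
    assert rows_to_delete(
        [
            ("remote", "@u:hs", 1, "DEV1"),
            ("remote", "@u:hs", 2, "DEV2"),
            ("remote", "@u:hs", 2, "DEV3"),
        ]
    ) == [("remote", "@u:hs", 1, "DEV1"), ("remote", "@u:hs", 2, "DEV3")]
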
@@ -908,8 +1008,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             name="device_id_exists", keylen=2, max_entries=10000
         )
 
-        self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
-
     async def store_device(
         self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
     ) -> bool:
@@ -1267,95 +1365,3 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 for device_id in device_ids
             ],
         )
-
-    def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000):
-        """Delete old entries out of the device_lists_outbound_pokes to ensure
-        that we don't fill up due to dead servers.
-
-        Normally, we try to send device updates as a delta since a previous known point:
-        this is done by setting the prev_id in the m.device_list_update EDU. However,
-        for that to work, we have to have a complete record of each change to
-        each device, which can add up to quite a lot of data.
-
-        An alternative mechanism is that, if the remote server sees that it has missed
-        an entry in the stream_id sequence for a given user, it will request a full
-        list of that user's devices. Hence, we can reduce the amount of data we have to
-        store (and transmit in some future transaction), by clearing almost everything
-        for a given destination out of the database, and having the remote server
-        resync.
-
-        All we need to do is make sure we keep at least one row for each
-        (user, destination) pair, to remind us to send an m.device_list_update EDU for
-        that user when the destination comes back. It doesn't matter which device
-        we keep.
-        """
-        yesterday = self._clock.time_msec() - prune_age
-
-        def _prune_txn(txn):
-            # look for (user, destination) pairs which have an update older than
-            # the cutoff.
-            #
-            # For each pair, we also need to know the most recent stream_id, and
-            # an arbitrary device_id at that stream_id.
-            select_sql = """
-            SELECT
-                dlop1.destination,
-                dlop1.user_id,
-                MAX(dlop1.stream_id) AS stream_id,
-                (SELECT MIN(dlop2.device_id) AS device_id FROM
-                    device_lists_outbound_pokes dlop2
-                    WHERE dlop2.destination = dlop1.destination AND
-                      dlop2.user_id=dlop1.user_id AND
-                      dlop2.stream_id=MAX(dlop1.stream_id)
-                )
-            FROM device_lists_outbound_pokes dlop1
-                GROUP BY destination, user_id
-                HAVING min(ts) < ? AND count(*) > 1
-            """
-
-            txn.execute(select_sql, (yesterday,))
-            rows = txn.fetchall()
-
-            if not rows:
-                return
-
-            logger.info(
-                "Pruning old outbound device list updates for %i users/destinations: %s",
-                len(rows),
-                shortstr((row[0], row[1]) for row in rows),
-            )
-
-            # we want to keep the update with the highest stream_id for each user.
-            #
-            # there might be more than one update (with different device_ids) at the
-            # max stream_id, so we also delete all but one of those rows.
-            delete_sql = """
-                DELETE FROM device_lists_outbound_pokes
-                WHERE destination = ? AND user_id = ? AND (
-                    stream_id < ? OR
-                    (stream_id = ? AND device_id != ?)
-                )
-            """
-            count = 0
-            for (destination, user_id, stream_id, device_id) in rows:
-                txn.execute(
-                    delete_sql, (destination, user_id, stream_id, stream_id, device_id)
-                )
-                count += txn.rowcount
-
-            # Since we've deleted unsent deltas, we also need to remove the record
-            # of the last successful send, so that the prev_ids are set correctly.
-            sql = """
-                DELETE FROM device_lists_outbound_last_success
-                WHERE destination = ? AND user_id = ?
-            """
-            txn.executemany(sql, ((row[0], row[1]) for row in rows))
-
-            logger.info("Pruned %d device list outbound pokes", count)
-
-        return run_as_background_process(
-            "prune_old_outbound_device_pokes",
-            self.db_pool.runInteraction,
-            "_prune_old_outbound_device_pokes",
-            _prune_txn,
-        )
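
For context on the resync mechanism the docstring relies on: device-list
deltas reach other servers as m.device_list_update EDUs. Roughly, per the
Matrix server-server spec (the values below are invented), the EDU content
looks like this; a receiver that has not seen every stream_id listed in
prev_id discards the delta and fetches the user's full device list instead:

    # Illustrative m.device_list_update EDU content (values invented):
    edu_content = {
        "user_id": "@alice:example.org",
        "device_id": "JLAFKJWSCS",
        "stream_id": 6,
        # stream_ids of the sender's previous updates; a receiver missing
        # stream_id 5 resyncs via GET /_matrix/federation/v1/user/devices/{userId}
        "prev_id": [5],
        "deleted": False,
    }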