diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 7674c187ef..c935c7be90 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -458,10 +458,12 @@ class DeviceHandler(DeviceWorkerHandler):
async def _prune_too_many_devices(self, user_id: str) -> None:
"""Delete any excess old devices this user may have."""
- device_ids = await self.store.check_too_many_devices_for_user(user_id)
+ device_ids = await self.store.check_too_many_devices_for_user(user_id, 100)
if not device_ids:
return
+ logger.info("Pruning %d old devices for user %s", len(device_ids), user_id)
+
# We don't want to block and try and delete tonnes of devices at once,
# so we cap the number of devices we delete synchronously.
first_batch, remaining_device_ids = device_ids[:10], device_ids[10:]
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 08ccd46a2b..95d4c0622d 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1569,11 +1569,15 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
return rows
- async def check_too_many_devices_for_user(self, user_id: str) -> List[str]:
+ async def check_too_many_devices_for_user(
+ self, user_id: str, limit: int
+ ) -> List[str]:
"""Check if the user has a lot of devices, and if so return the set of
devices we can prune.
This does *not* return hidden devices or devices with E2E keys.
+
+ Returns at most `limit` number of devices, ordered by last seen.
"""
num_devices = await self.db_pool.simple_select_one_onecol(
@@ -1614,7 +1618,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
# Now fetch the devices to delete.
sql = """
- SELECT DISTINCT device_id FROM devices
+ SELECT device_id FROM devices
LEFT JOIN e2e_device_keys_json USING (user_id, device_id)
WHERE
user_id = ?
@@ -1622,12 +1626,13 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
AND last_seen < ?
AND key_json IS NULL
ORDER BY last_seen
+ LIMIT ?
"""
def check_too_many_devices_for_user_txn(
txn: LoggingTransaction,
) -> List[str]:
- txn.execute(sql, (user_id, max_last_seen))
+ txn.execute(sql, (user_id, max_last_seen, limit))
return [device_id for device_id, in txn]
return await self.db_pool.runInteraction(
|