diff --git a/changelog.d/7159.bugfix b/changelog.d/7159.bugfix
new file mode 100644
index 0000000000..1b341b127b
--- /dev/null
+++ b/changelog.d/7159.bugfix
@@ -0,0 +1 @@
+Fix excessive CPU usage by `prune_old_outbound_device_pokes` job.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 38ab6a8fc3..c7aa7acf3b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -49,6 +49,7 @@ from synapse.event_auth import auth_types_for_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
+from synapse.handlers._base import BaseHandler
from synapse.logging.context import (
make_deferred_yieldable,
nested_logging_context,
@@ -69,10 +70,9 @@ from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
+from synapse.util.stringutils import shortstr
from synapse.visibility import filter_events_for_server
-from ._base import BaseHandler
-
logger = logging.getLogger(__name__)
@@ -93,27 +93,6 @@ class _NewEventInfo:
auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
-def shortstr(iterable, maxitems=5):
- """If iterable has maxitems or fewer, return the stringification of a list
- containing those items.
-
- Otherwise, return the stringification of a a list with the first maxitems items,
- followed by "...".
-
- Args:
- iterable (Iterable): iterable to truncate
- maxitems (int): number of items to return before truncating
-
- Returns:
- unicode
- """
-
- items = list(itertools.islice(iterable, maxitems + 1))
- if len(items) <= maxitems:
- return str(items)
- return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
-
-
class FederationHandler(BaseHandler):
"""Handles events that originated from federation.
Responsible for:
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 2d47cfd131..3140e1b722 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -41,6 +41,7 @@ from synapse.util.caches.descriptors import (
cachedList,
)
from synapse.util.iterutils import batch_iter
+from synapse.util.stringutils import shortstr
logger = logging.getLogger(__name__)
@@ -1092,18 +1093,47 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
],
)
- def _prune_old_outbound_device_pokes(self):
+ def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000):
"""Delete old entries out of the device_lists_outbound_pokes to ensure
- that we don't fill up due to dead servers. We keep one entry per
- (destination, user_id) tuple to ensure that the prev_ids remain correct
- if the server does come back.
+ that we don't fill up due to dead servers.
+
+ Normally, we try to send device updates as a delta since a previous known point:
+ this is done by setting the prev_id in the m.device_list_update EDU. However,
+ for that to work, we have to have a complete record of each change to
+ each device, which can add up to quite a lot of data.
+
+ An alternative mechanism is that, if the remote server sees that it has missed
+ an entry in the stream_id sequence for a given user, it will request a full
+ list of that user's devices. Hence, we can reduce the amount of data we have to
+ store (and transmit in some future transaction), by clearing almost everything
+ for a given destination out of the database, and having the remote server
+ resync.
+
+ All we need to do is make sure we keep at least one row for each
+ (user, destination) pair, to remind us to send an m.device_list_update EDU
+ for that user when the destination comes back. It doesn't matter which
+ device we keep.
+
+ Args:
+ prune_age (int): rows with a ts older than this many milliseconds are
+ considered for pruning. Defaults to 24 hours.
+ """
- yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000
+ yesterday = self._clock.time_msec() - prune_age
def _prune_txn(txn):
+ # look for (user, destination) pairs which have an update older than
+ # the cutoff.
+ #
+ # For each pair, we also need to know the most recent stream_id, and
+ # an arbitrary device_id at that stream_id.
select_sql = """
- SELECT destination, user_id, max(stream_id) as stream_id
- FROM device_lists_outbound_pokes
+ SELECT
+ dlop1.destination,
+ dlop1.user_id,
+ MAX(dlop1.stream_id) AS stream_id,
+ (SELECT MIN(dlop2.device_id) AS device_id FROM
+ device_lists_outbound_pokes dlop2
+ WHERE dlop2.destination = dlop1.destination AND
+ dlop2.user_id = dlop1.user_id AND
+ dlop2.stream_id = MAX(dlop1.stream_id)
+ )
+ FROM device_lists_outbound_pokes dlop1
GROUP BY destination, user_id
HAVING min(ts) < ? AND count(*) > 1
"""
@@ -1114,14 +1144,29 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not rows:
return
+ logger.info(
+ "Pruning old outbound device list updates for %i users/destinations: %s",
+ len(rows),
+ shortstr((row[0], row[1]) for row in rows),
+ )
+
+ # we want to keep the update with the highest stream_id for each
+ # (destination, user) pair.
+ #
+ # there might be more than one update (with different device_ids) at that
+ # stream_id, so we also delete all but one of the rows with the max stream_id.
delete_sql = """
DELETE FROM device_lists_outbound_pokes
- WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ?
+ WHERE destination = ? AND user_id = ? AND (
+ stream_id < ? OR
+ (stream_id = ? AND device_id != ?)
+ )
"""
-
- txn.executemany(
- delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows)
- )
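+ # delete each (destination, user) pair with its own statement, rather
+ # than an executemany, so that we can total up txn.rowcount for the log
+ # line below (DBAPI drivers do not reliably report rowcount for
+ # executemany).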
+ count = 0
+ for (destination, user_id, stream_id, device_id) in rows:
+ txn.execute(
+ delete_sql, (destination, user_id, stream_id, stream_id, device_id)
+ )
+ count += txn.rowcount
# Since we've deleted unsent deltas, we need to remove the entry
# of last successful sent so that the prev_ids are correctly set.
@@ -1131,7 +1176,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"""
txn.executemany(sql, ((row[0], row[1]) for row in rows))
- logger.info("Pruned %d device list outbound pokes", txn.rowcount)
+ logger.info("Pruned %d device list outbound pokes", count)
return run_as_background_process(
"prune_old_outbound_device_pokes",
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 2c0dcb5208..6899bcb788 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -13,10 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import itertools
import random
import re
import string
+from typing import Iterable
import six
from six import PY2, PY3
@@ -126,3 +127,21 @@ def assert_valid_client_secret(client_secret):
raise SynapseError(
400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
)
+
+
+def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
+ """If iterable has maxitems or fewer, return the stringification of a list
+ containing those items.
+
+ Otherwise, return the stringification of a list with the first maxitems items,
+ followed by "...".
+
+ Args:
+ iterable: iterable to truncate
+ maxitems: number of items to return before truncating
+ """
+
+ items = list(itertools.islice(iterable, maxitems + 1))
+ if len(items) <= maxitems:
+ return str(items)
+ return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 7763b12159..a5fe5c6880 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -370,6 +370,98 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
devices = {edu["content"]["device_id"] for edu in self.edus}
self.assertEqual({"D1", "D2", "D3"}, devices)
+ def test_prune_outbound_device_pokes1(self):
+ """If a destination is unreachable, and the updates are pruned, we should get
+ a single update.
+
+ This case tests the behaviour when the server has never been reachable.
+ """
+ mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn.side_effect = lambda t, cb: defer.fail("fail")
+
+ # create devices
+ u1 = self.register_user("user", "pass")
+ self.login("user", "pass", device_id="D1")
+ self.login("user", "pass", device_id="D2")
+ self.login("user", "pass", device_id="D3")
+
+ # delete them again
+ self.get_success(
+ self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
+ )
+
+ self.assertGreaterEqual(mock_send_txn.call_count, 4)
+
+ # run the prune job
+ self.reactor.advance(10)
+ self.get_success(
+ self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1)
+ )
+
+ # recover the server
+ mock_send_txn.side_effect = self.record_transaction
+ self.hs.get_federation_sender().send_device_messages("host2")
+ self.pump()
+
+ # there should be a single update for this user.
+ self.assertEqual(len(self.edus), 1)
+ edu = self.edus.pop(0)
+ self.assertEqual(edu["edu_type"], "m.device_list_update")
+ c = edu["content"]
+
+ # synapse uses an empty prev_id list to indicate "needs a full resync".
+ self.assertEqual(c["prev_id"], [])
+
+ def test_prune_outbound_device_pokes2(self):
+ """If a destination is unreachable, and the updates are pruned, we should get
+ a single update.
+
+ This case tests the behaviour when the server was reachable, but then went
+ offline.
+ """
+
+ # create first device
+ u1 = self.register_user("user", "pass")
+ self.login("user", "pass", device_id="D1")
+
+ # expect the update EDU
+ self.assertEqual(len(self.edus), 1)
+ self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
+
+ # now the server goes offline
+ mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn.side_effect = lambda t, cb: defer.fail("fail")
+
+ self.login("user", "pass", device_id="D2")
+ self.login("user", "pass", device_id="D3")
+
+ # delete them again
+ self.get_success(
+ self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
+ )
+
+ self.assertGreaterEqual(mock_send_txn.call_count, 3)
+
+ # run the prune job
+ self.reactor.advance(10)
+ self.get_success(
+ self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1)
+ )
+
+ # recover the server
+ mock_send_txn.side_effect = self.record_transaction
+ self.hs.get_federation_sender().send_device_messages("host2")
+ self.pump()
+
+ # ... and we should get a single update for this user.
+ self.assertEqual(len(self.edus), 1)
+ edu = self.edus.pop(0)
+ self.assertEqual(edu["edu_type"], "m.device_list_update")
+ c = edu["content"]
+
+ # synapse uses an empty prev_id list to indicate "needs a full resync".
+ self.assertEqual(c["prev_id"], [])
+
def check_device_update_edu(
self,
edu: JsonDict,
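For context on the assertions in these tests: after pruning, the single surviving row for each (destination, user) pair produces an m.device_list_update EDU whose prev_id list is empty, which the remote server interprets as "resync this user's full device list". A hypothetical example of such an EDU follows; the field values are invented for illustration and not taken from the patch.

```python
# Hypothetical m.device_list_update EDU after pruning; the empty prev_id list
# is what tells the remote server to do a full resync of the user's devices.
edu = {
    "edu_type": "m.device_list_update",
    "content": {
        "user_id": "@user:test",
        "device_id": "D3",
        "stream_id": 3,
        "prev_id": [],  # no usable delta: remote must resync
        "deleted": True,
    },
}
assert edu["content"]["prev_id"] == []
```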