summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/7159.bugfix1
-rw-r--r--synapse/handlers/federation.py25
-rw-r--r--synapse/storage/data_stores/main/devices.py71
-rw-r--r--synapse/util/stringutils.py21
-rw-r--r--tests/federation/test_federation_sender.py92
5 files changed, 173 insertions, 37 deletions
diff --git a/changelog.d/7159.bugfix b/changelog.d/7159.bugfix
new file mode 100644
index 0000000000..1b341b127b
--- /dev/null
+++ b/changelog.d/7159.bugfix
@@ -0,0 +1 @@
+Fix excessive CPU usage by `prune_old_outbound_device_pokes` job.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 38ab6a8fc3..c7aa7acf3b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -49,6 +49,7 @@ from synapse.event_auth import auth_types_for_event
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
+from synapse.handlers._base import BaseHandler
 from synapse.logging.context import (
     make_deferred_yieldable,
     nested_logging_context,
@@ -69,10 +70,9 @@ from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
+from synapse.util.stringutils import shortstr
 from synapse.visibility import filter_events_for_server
 
-from ._base import BaseHandler
-
 logger = logging.getLogger(__name__)
 
 
@@ -93,27 +93,6 @@ class _NewEventInfo:
     auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
 
 
-def shortstr(iterable, maxitems=5):
-    """If iterable has maxitems or fewer, return the stringification of a list
-    containing those items.
-
-    Otherwise, return the stringification of a a list with the first maxitems items,
-    followed by "...".
-
-    Args:
-        iterable (Iterable): iterable to truncate
-        maxitems (int): number of items to return before truncating
-
-    Returns:
-        unicode
-    """
-
-    items = list(itertools.islice(iterable, maxitems + 1))
-    if len(items) <= maxitems:
-        return str(items)
-    return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
-
-
 class FederationHandler(BaseHandler):
     """Handles events that originated from federation.
         Responsible for:
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 2d47cfd131..3140e1b722 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -41,6 +41,7 @@ from synapse.util.caches.descriptors import (
     cachedList,
 )
 from synapse.util.iterutils import batch_iter
+from synapse.util.stringutils import shortstr
 
 logger = logging.getLogger(__name__)
 
@@ -1092,18 +1093,47 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             ],
         )
 
-    def _prune_old_outbound_device_pokes(self):
+    def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000):
         """Delete old entries out of the device_lists_outbound_pokes to ensure
-        that we don't fill up due to dead servers. We keep one entry per
-        (destination, user_id) tuple to ensure that the prev_ids remain correct
-        if the server does come back.
+        that we don't fill up due to dead servers.
+
+        Normally, we try to send device updates as a delta since a previous known point:
+        this is done by setting the prev_id in the m.device_list_update EDU. However,
+        for that to work, we have to have a complete record of each change to
+        each device, which can add up to quite a lot of data.
+
+        An alternative mechanism is that, if the remote server sees that it has missed
+        an entry in the stream_id sequence for a given user, it will request a full
+        list of that user's devices. Hence, we can reduce the amount of data we have to
+        store (and transmit in some future transaction), by clearing almost everything
+        for a given destination out of the database, and having the remote server
+        resync.
+
+        All we need to do is make sure we keep at least one row for each
+        (user, destination) pair, to remind us to send a m.device_list_update EDU for
+        that user when the destination comes back. It doesn't matter which device
+        we keep.
         """
-        yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000
+        yesterday = self._clock.time_msec() - prune_age
 
         def _prune_txn(txn):
+            # look for (user, destination) pairs which have an update older than
+            # the cutoff.
+            #
+            # For each pair, we also need to know the most recent stream_id, and
+            # an arbitrary device_id at that stream_id.
             select_sql = """
-                SELECT destination, user_id, max(stream_id) as stream_id
-                FROM device_lists_outbound_pokes
+            SELECT
+                dlop1.destination,
+                dlop1.user_id,
+                MAX(dlop1.stream_id) AS stream_id,
+                (SELECT MIN(dlop2.device_id) AS device_id FROM
+                    device_lists_outbound_pokes dlop2
+                    WHERE dlop2.destination = dlop1.destination AND
+                      dlop2.user_id=dlop1.user_id AND
+                      dlop2.stream_id=MAX(dlop1.stream_id)
+                )
+            FROM device_lists_outbound_pokes dlop1
                 GROUP BY destination, user_id
                 HAVING min(ts) < ? AND count(*) > 1
             """
@@ -1114,14 +1144,29 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             if not rows:
                 return
 
+            logger.info(
+                "Pruning old outbound device list updates for %i users/destinations: %s",
+                len(rows),
+                shortstr((row[0], row[1]) for row in rows),
+            )
+
+            # we want to keep the update with the highest stream_id for each user.
+            #
+            # there might be more than one update (with different device_ids) with the
+            # same stream_id, so we also delete all but one rows with the max stream id.
             delete_sql = """
                 DELETE FROM device_lists_outbound_pokes
-                WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ?
+                WHERE destination = ? AND user_id = ? AND (
+                    stream_id < ? OR
+                    (stream_id = ? AND device_id != ?)
+                )
             """
-
-            txn.executemany(
-                delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows)
-            )
+            count = 0
+            for (destination, user_id, stream_id, device_id) in rows:
+                txn.execute(
+                    delete_sql, (destination, user_id, stream_id, stream_id, device_id)
+                )
+                count += txn.rowcount
 
             # Since we've deleted unsent deltas, we need to remove the entry
             # of last successful sent so that the prev_ids are correctly set.
@@ -1131,7 +1176,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             """
             txn.executemany(sql, ((row[0], row[1]) for row in rows))
 
-            logger.info("Pruned %d device list outbound pokes", txn.rowcount)
+            logger.info("Pruned %d device list outbound pokes", count)
 
         return run_as_background_process(
             "prune_old_outbound_device_pokes",
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 2c0dcb5208..6899bcb788 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -13,10 +13,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import itertools
 import random
 import re
 import string
+from collections import Iterable
 
 import six
 from six import PY2, PY3
@@ -126,3 +127,21 @@ def assert_valid_client_secret(client_secret):
         raise SynapseError(
             400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
         )
+
+
+def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
+    """If iterable has maxitems or fewer, return the stringification of a list
+    containing those items.
+
+    Otherwise, return the stringification of a a list with the first maxitems items,
+    followed by "...".
+
+    Args:
+        iterable: iterable to truncate
+        maxitems: number of items to return before truncating
+    """
+
+    items = list(itertools.islice(iterable, maxitems + 1))
+    if len(items) <= maxitems:
+        return str(items)
+    return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 7763b12159..a5fe5c6880 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -370,6 +370,98 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         devices = {edu["content"]["device_id"] for edu in self.edus}
         self.assertEqual({"D1", "D2", "D3"}, devices)
 
+    def test_prune_outbound_device_pokes1(self):
+        """If a destination is unreachable, and the updates are pruned, we should get
+        a single update.
+
+        This case tests the behaviour when the server has never been reachable.
+        """
+        mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+        mock_send_txn.side_effect = lambda t, cb: defer.fail("fail")
+
+        # create devices
+        u1 = self.register_user("user", "pass")
+        self.login("user", "pass", device_id="D1")
+        self.login("user", "pass", device_id="D2")
+        self.login("user", "pass", device_id="D3")
+
+        # delete them again
+        self.get_success(
+            self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
+        )
+
+        self.assertGreaterEqual(mock_send_txn.call_count, 4)
+
+        # run the prune job
+        self.reactor.advance(10)
+        self.get_success(
+            self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1)
+        )
+
+        # recover the server
+        mock_send_txn.side_effect = self.record_transaction
+        self.hs.get_federation_sender().send_device_messages("host2")
+        self.pump()
+
+        # there should be a single update for this user.
+        self.assertEqual(len(self.edus), 1)
+        edu = self.edus.pop(0)
+        self.assertEqual(edu["edu_type"], "m.device_list_update")
+        c = edu["content"]
+
+        # synapse uses an empty prev_id list to indicate "needs a full resync".
+        self.assertEqual(c["prev_id"], [])
+
+    def test_prune_outbound_device_pokes2(self):
+        """If a destination is unreachable, and the updates are pruned, we should get
+        a single update.
+
+        This case tests the behaviour when the server was reachable, but then goes
+        offline.
+        """
+
+        # create first device
+        u1 = self.register_user("user", "pass")
+        self.login("user", "pass", device_id="D1")
+
+        # expect the update EDU
+        self.assertEqual(len(self.edus), 1)
+        self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
+
+        # now the server goes offline
+        mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+        mock_send_txn.side_effect = lambda t, cb: defer.fail("fail")
+
+        self.login("user", "pass", device_id="D2")
+        self.login("user", "pass", device_id="D3")
+
+        # delete them again
+        self.get_success(
+            self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
+        )
+
+        self.assertGreaterEqual(mock_send_txn.call_count, 3)
+
+        # run the prune job
+        self.reactor.advance(10)
+        self.get_success(
+            self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1)
+        )
+
+        # recover the server
+        mock_send_txn.side_effect = self.record_transaction
+        self.hs.get_federation_sender().send_device_messages("host2")
+        self.pump()
+
+        # ... and we should get a single update for this user.
+        self.assertEqual(len(self.edus), 1)
+        edu = self.edus.pop(0)
+        self.assertEqual(edu["edu_type"], "m.device_list_update")
+        c = edu["content"]
+
+        # synapse uses an empty prev_id list to indicate "needs a full resync".
+        self.assertEqual(c["prev_id"], [])
+
     def check_device_update_edu(
         self,
         edu: JsonDict,