summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erikj@element.io>2024-06-24 14:15:13 +0100
committerGitHub <noreply@github.com>2024-06-24 14:15:13 +0100
commitcf711ac03cd88b70568b3ac9df4aed4de5b33523 (patch)
tree5903fca3c4fa9ae66c23bafcb4746a8c6365d890 /synapse
parentTidy up integer parsing (#17339) (diff)
downloadsynapse-cf711ac03cd88b70568b3ac9df4aed4de5b33523.tar.xz
Reduce device lists replication traffic. (#17333)
Reduce the replication traffic of device lists, by not sending every
destination that needs to be sent the device list update over
replication. Instead a "hosts to send to have been calculated"
notification over replication, and then federation senders read the
destinations from the DB.

For non federation senders this should heavily reduce the impact of a
user in many large rooms changing a device.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/replication/tcp/client.py19
-rw-r--r--synapse/replication/tcp/streams/_base.py12
-rw-r--r--synapse/storage/databases/main/devices.py93
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py4
4 files changed, 80 insertions, 48 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 2d6d49eed7..3dddbb70b4 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -114,13 +114,19 @@ class ReplicationDataHandler:
         """
         all_room_ids: Set[str] = set()
         if stream_name == DeviceListsStream.NAME:
-            if any(row.entity.startswith("@") and not row.is_signature for row in rows):
+            if any(not row.is_signature and not row.hosts_calculated for row in rows):
                 prev_token = self.store.get_device_stream_token()
                 all_room_ids = await self.store.get_all_device_list_changes(
                     prev_token, token
                 )
                 self.store.device_lists_in_rooms_have_changed(all_room_ids, token)
 
+            # If we're sending federation we need to update the device lists
+            # outbound pokes stream change cache with updated hosts.
+            if self.send_handler and any(row.hosts_calculated for row in rows):
+                hosts = await self.store.get_destinations_for_device(token)
+                self.store.device_lists_outbound_pokes_have_changed(hosts, token)
+
         self.store.process_replication_rows(stream_name, instance_name, token, rows)
         # NOTE: this must be called after process_replication_rows to ensure any
         # cache invalidations are first handled before any stream ID advances.
@@ -433,12 +439,11 @@ class FederationSenderHandler:
             # The entities are either user IDs (starting with '@') whose devices
             # have changed, or remote servers that we need to tell about
             # changes.
-            hosts = {
-                row.entity
-                for row in rows
-                if not row.entity.startswith("@") and not row.is_signature
-            }
-            await self.federation_sender.send_device_messages(hosts, immediate=False)
+            if any(row.hosts_calculated for row in rows):
+                hosts = await self.store.get_destinations_for_device(token)
+                await self.federation_sender.send_device_messages(
+                    hosts, immediate=False
+                )
 
         elif stream_name == ToDeviceStream.NAME:
             # The to_device stream includes stuff to be pushed to both local
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 661206c841..d021904de7 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -549,10 +549,14 @@ class DeviceListsStream(_StreamFromIdGen):
 
     @attr.s(slots=True, frozen=True, auto_attribs=True)
     class DeviceListsStreamRow:
-        entity: str
+        user_id: str
         # Indicates that a user has signed their own device with their user-signing key
         is_signature: bool
 
+        # Indicates if this is a notification that we've calculated the hosts we
+        # need to send the update to.
+        hosts_calculated: bool
+
     NAME = "device_lists"
     ROW_TYPE = DeviceListsStreamRow
 
@@ -594,13 +598,13 @@ class DeviceListsStream(_StreamFromIdGen):
             upper_limit_token = min(upper_limit_token, signatures_to_token)
 
         device_updates = [
-            (stream_id, (entity, False))
-            for stream_id, (entity,) in device_updates
+            (stream_id, (entity, False, hosts))
+            for stream_id, (entity, hosts) in device_updates
             if stream_id <= upper_limit_token
         ]
 
         signatures_updates = [
-            (stream_id, (entity, True))
+            (stream_id, (entity, True, False))
             for stream_id, (entity,) in signatures_updates
             if stream_id <= upper_limit_token
         ]
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 40187496e2..5eeca6165d 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -164,22 +164,24 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             prefilled_cache=user_signature_stream_prefill,
         )
 
-        (
-            device_list_federation_prefill,
-            device_list_federation_list_id,
-        ) = self.db_pool.get_cache_dict(
-            db_conn,
-            "device_lists_outbound_pokes",
-            entity_column="destination",
-            stream_column="stream_id",
-            max_value=device_list_max,
-            limit=10000,
-        )
-        self._device_list_federation_stream_cache = StreamChangeCache(
-            "DeviceListFederationStreamChangeCache",
-            device_list_federation_list_id,
-            prefilled_cache=device_list_federation_prefill,
-        )
+        self._device_list_federation_stream_cache = None
+        if hs.should_send_federation():
+            (
+                device_list_federation_prefill,
+                device_list_federation_list_id,
+            ) = self.db_pool.get_cache_dict(
+                db_conn,
+                "device_lists_outbound_pokes",
+                entity_column="destination",
+                stream_column="stream_id",
+                max_value=device_list_max,
+                limit=10000,
+            )
+            self._device_list_federation_stream_cache = StreamChangeCache(
+                "DeviceListFederationStreamChangeCache",
+                device_list_federation_list_id,
+                prefilled_cache=device_list_federation_prefill,
+            )
 
         if hs.config.worker.run_background_tasks:
             self._clock.looping_call(
@@ -207,23 +209,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     ) -> None:
         for row in rows:
             if row.is_signature:
-                self._user_signature_stream_cache.entity_has_changed(row.entity, token)
+                self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
                 continue
 
             # The entities are either user IDs (starting with '@') whose devices
             # have changed, or remote servers that we need to tell about
             # changes.
-            if row.entity.startswith("@"):
-                self._device_list_stream_cache.entity_has_changed(row.entity, token)
-                self.get_cached_devices_for_user.invalidate((row.entity,))
-                self._get_cached_user_device.invalidate((row.entity,))
-                self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
-
-            else:
-                self._device_list_federation_stream_cache.entity_has_changed(
-                    row.entity, token
+            if not row.hosts_calculated:
+                self._device_list_stream_cache.entity_has_changed(row.user_id, token)
+                self.get_cached_devices_for_user.invalidate((row.user_id,))
+                self._get_cached_user_device.invalidate((row.user_id,))
+                self.get_device_list_last_stream_id_for_remote.invalidate(
+                    (row.user_id,)
                 )
 
+    def device_lists_outbound_pokes_have_changed(
+        self, destinations: StrCollection, token: int
+    ) -> None:
+        assert self._device_list_federation_stream_cache is not None
+
+        for destination in destinations:
+            self._device_list_federation_stream_cache.entity_has_changed(
+                destination, token
+            )
+
     def device_lists_in_rooms_have_changed(
         self, room_ids: StrCollection, token: int
     ) -> None:
@@ -363,6 +372,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
               EDU contents.
         """
         now_stream_id = self.get_device_stream_token()
+        if from_stream_id == now_stream_id:
+            return now_stream_id, []
+
+        if self._device_list_federation_stream_cache is None:
+            raise Exception("Func can only be used on federation senders")
 
         has_changed = self._device_list_federation_stream_cache.has_entity_changed(
             destination, int(from_stream_id)
@@ -1018,10 +1032,10 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             # This query Does The Right Thing where it'll correctly apply the
             # bounds to the inner queries.
             sql = """
-                SELECT stream_id, entity FROM (
-                    SELECT stream_id, user_id AS entity FROM device_lists_stream
+                SELECT stream_id, user_id, hosts FROM (
+                    SELECT stream_id, user_id, false AS hosts FROM device_lists_stream
                     UNION ALL
-                    SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+                    SELECT DISTINCT stream_id, user_id, true AS hosts FROM device_lists_outbound_pokes
                 ) AS e
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
@@ -1577,6 +1591,14 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             get_device_list_changes_in_room_txn,
         )
 
+    async def get_destinations_for_device(self, stream_id: int) -> StrCollection:
+        return await self.db_pool.simple_select_onecol(
+            table="device_lists_outbound_pokes",
+            keyvalues={"stream_id": stream_id},
+            retcol="destination",
+            desc="get_destinations_for_device",
+        )
+
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
     def __init__(
@@ -2112,12 +2134,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         stream_ids: List[int],
         context: Optional[Dict[str, str]],
     ) -> None:
-        for host in hosts:
-            txn.call_after(
-                self._device_list_federation_stream_cache.entity_has_changed,
-                host,
-                stream_ids[-1],
-            )
+        if self._device_list_federation_stream_cache:
+            for host in hosts:
+                txn.call_after(
+                    self._device_list_federation_stream_cache.entity_has_changed,
+                    host,
+                    stream_ids[-1],
+                )
 
         now = self._clock.time_msec()
         stream_id_iterator = iter(stream_ids)
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 38d8785faa..9e6c9561ae 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -123,9 +123,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         if stream_name == DeviceListsStream.NAME:
             for row in rows:
                 assert isinstance(row, DeviceListsStream.DeviceListsStreamRow)
-                if row.entity.startswith("@"):
+                if not row.hosts_calculated:
                     self._get_e2e_device_keys_for_federation_query_inner.invalidate(
-                        (row.entity,)
+                        (row.user_id,)
                     )
 
         super().process_replication_rows(stream_name, instance_name, token, rows)