summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorErik Johnston <erikj@element.io>2024-06-24 14:15:13 +0100
committerGitHub <noreply@github.com>2024-06-24 14:15:13 +0100
commitcf711ac03cd88b70568b3ac9df4aed4de5b33523 (patch)
tree5903fca3c4fa9ae66c23bafcb4746a8c6365d890 /synapse/storage
parentTidy up integer parsing (#17339) (diff)
downloadsynapse-cf711ac03cd88b70568b3ac9df4aed4de5b33523.tar.xz
Reduce device lists replication traffic. (#17333)
Reduce the replication traffic of device lists, by not sending every
destination that needs to be sent the device list update over
replication. Instead a "hosts to send to have been calculated"
notification over replication, and then federation senders read the
destinations from the DB.

For non federation senders this should heavily reduce the impact of a
user in many large rooms changing a device.
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/devices.py93
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py4
2 files changed, 60 insertions, 37 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 40187496e2..5eeca6165d 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -164,22 +164,24 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             prefilled_cache=user_signature_stream_prefill,
         )
 
-        (
-            device_list_federation_prefill,
-            device_list_federation_list_id,
-        ) = self.db_pool.get_cache_dict(
-            db_conn,
-            "device_lists_outbound_pokes",
-            entity_column="destination",
-            stream_column="stream_id",
-            max_value=device_list_max,
-            limit=10000,
-        )
-        self._device_list_federation_stream_cache = StreamChangeCache(
-            "DeviceListFederationStreamChangeCache",
-            device_list_federation_list_id,
-            prefilled_cache=device_list_federation_prefill,
-        )
+        self._device_list_federation_stream_cache = None
+        if hs.should_send_federation():
+            (
+                device_list_federation_prefill,
+                device_list_federation_list_id,
+            ) = self.db_pool.get_cache_dict(
+                db_conn,
+                "device_lists_outbound_pokes",
+                entity_column="destination",
+                stream_column="stream_id",
+                max_value=device_list_max,
+                limit=10000,
+            )
+            self._device_list_federation_stream_cache = StreamChangeCache(
+                "DeviceListFederationStreamChangeCache",
+                device_list_federation_list_id,
+                prefilled_cache=device_list_federation_prefill,
+            )
 
         if hs.config.worker.run_background_tasks:
             self._clock.looping_call(
@@ -207,23 +209,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     ) -> None:
         for row in rows:
             if row.is_signature:
-                self._user_signature_stream_cache.entity_has_changed(row.entity, token)
+                self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
                 continue
 
             # The entities are either user IDs (starting with '@') whose devices
             # have changed, or remote servers that we need to tell about
             # changes.
-            if row.entity.startswith("@"):
-                self._device_list_stream_cache.entity_has_changed(row.entity, token)
-                self.get_cached_devices_for_user.invalidate((row.entity,))
-                self._get_cached_user_device.invalidate((row.entity,))
-                self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
-
-            else:
-                self._device_list_federation_stream_cache.entity_has_changed(
-                    row.entity, token
+            if not row.hosts_calculated:
+                self._device_list_stream_cache.entity_has_changed(row.user_id, token)
+                self.get_cached_devices_for_user.invalidate((row.user_id,))
+                self._get_cached_user_device.invalidate((row.user_id,))
+                self.get_device_list_last_stream_id_for_remote.invalidate(
+                    (row.user_id,)
                 )
 
+    def device_lists_outbound_pokes_have_changed(
+        self, destinations: StrCollection, token: int
+    ) -> None:
+        assert self._device_list_federation_stream_cache is not None
+
+        for destination in destinations:
+            self._device_list_federation_stream_cache.entity_has_changed(
+                destination, token
+            )
+
     def device_lists_in_rooms_have_changed(
         self, room_ids: StrCollection, token: int
     ) -> None:
@@ -363,6 +372,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
               EDU contents.
         """
         now_stream_id = self.get_device_stream_token()
+        if from_stream_id == now_stream_id:
+            return now_stream_id, []
+
+        if self._device_list_federation_stream_cache is None:
+            raise Exception("Func can only be used on federation senders")
 
         has_changed = self._device_list_federation_stream_cache.has_entity_changed(
             destination, int(from_stream_id)
@@ -1018,10 +1032,10 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             # This query Does The Right Thing where it'll correctly apply the
             # bounds to the inner queries.
             sql = """
-                SELECT stream_id, entity FROM (
-                    SELECT stream_id, user_id AS entity FROM device_lists_stream
+                SELECT stream_id, user_id, hosts FROM (
+                    SELECT stream_id, user_id, false AS hosts FROM device_lists_stream
                     UNION ALL
-                    SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+                    SELECT DISTINCT stream_id, user_id, true AS hosts FROM device_lists_outbound_pokes
                 ) AS e
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
@@ -1577,6 +1591,14 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             get_device_list_changes_in_room_txn,
         )
 
+    async def get_destinations_for_device(self, stream_id: int) -> StrCollection:
+        return await self.db_pool.simple_select_onecol(
+            table="device_lists_outbound_pokes",
+            keyvalues={"stream_id": stream_id},
+            retcol="destination",
+            desc="get_destinations_for_device",
+        )
+
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
     def __init__(
@@ -2112,12 +2134,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         stream_ids: List[int],
         context: Optional[Dict[str, str]],
     ) -> None:
-        for host in hosts:
-            txn.call_after(
-                self._device_list_federation_stream_cache.entity_has_changed,
-                host,
-                stream_ids[-1],
-            )
+        if self._device_list_federation_stream_cache:
+            for host in hosts:
+                txn.call_after(
+                    self._device_list_federation_stream_cache.entity_has_changed,
+                    host,
+                    stream_ids[-1],
+                )
 
         now = self._clock.time_msec()
         stream_id_iterator = iter(stream_ids)
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 38d8785faa..9e6c9561ae 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -123,9 +123,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         if stream_name == DeviceListsStream.NAME:
             for row in rows:
                 assert isinstance(row, DeviceListsStream.DeviceListsStreamRow)
-                if row.entity.startswith("@"):
+                if not row.hosts_calculated:
                     self._get_e2e_device_keys_for_federation_query_inner.invalidate(
-                        (row.entity,)
+                        (row.user_id,)
                     )
 
         super().process_replication_rows(stream_name, instance_name, token, rows)