summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/17333.misc1
-rw-r--r--synapse/replication/tcp/client.py19
-rw-r--r--synapse/replication/tcp/streams/_base.py12
-rw-r--r--synapse/storage/databases/main/devices.py93
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py4
-rw-r--r--tests/storage/test_devices.py8
6 files changed, 89 insertions, 48 deletions
diff --git a/changelog.d/17333.misc b/changelog.d/17333.misc
new file mode 100644
index 0000000000..d3ef0b3777
--- /dev/null
+++ b/changelog.d/17333.misc
@@ -0,0 +1 @@
+Handle device lists notifications for large accounts more efficiently in worker mode.
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 2d6d49eed7..3dddbb70b4 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -114,13 +114,19 @@ class ReplicationDataHandler:
         """
         all_room_ids: Set[str] = set()
         if stream_name == DeviceListsStream.NAME:
-            if any(row.entity.startswith("@") and not row.is_signature for row in rows):
+            if any(not row.is_signature and not row.hosts_calculated for row in rows):
                 prev_token = self.store.get_device_stream_token()
                 all_room_ids = await self.store.get_all_device_list_changes(
                     prev_token, token
                 )
                 self.store.device_lists_in_rooms_have_changed(all_room_ids, token)
 
+            # If we're sending federation we need to update the device lists
+            # outbound pokes stream change cache with updated hosts.
+            if self.send_handler and any(row.hosts_calculated for row in rows):
+                hosts = await self.store.get_destinations_for_device(token)
+                self.store.device_lists_outbound_pokes_have_changed(hosts, token)
+
         self.store.process_replication_rows(stream_name, instance_name, token, rows)
         # NOTE: this must be called after process_replication_rows to ensure any
         # cache invalidations are first handled before any stream ID advances.
@@ -433,12 +439,11 @@ class FederationSenderHandler:
             # The entities are either user IDs (starting with '@') whose devices
             # have changed, or remote servers that we need to tell about
             # changes.
-            hosts = {
-                row.entity
-                for row in rows
-                if not row.entity.startswith("@") and not row.is_signature
-            }
-            await self.federation_sender.send_device_messages(hosts, immediate=False)
+            if any(row.hosts_calculated for row in rows):
+                hosts = await self.store.get_destinations_for_device(token)
+                await self.federation_sender.send_device_messages(
+                    hosts, immediate=False
+                )
 
         elif stream_name == ToDeviceStream.NAME:
             # The to_device stream includes stuff to be pushed to both local
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 661206c841..d021904de7 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -549,10 +549,14 @@ class DeviceListsStream(_StreamFromIdGen):
 
     @attr.s(slots=True, frozen=True, auto_attribs=True)
     class DeviceListsStreamRow:
-        entity: str
+        user_id: str
         # Indicates that a user has signed their own device with their user-signing key
         is_signature: bool
 
+        # Indicates if this is a notification that we've calculated the hosts we
+        # need to send the update to.
+        hosts_calculated: bool
+
     NAME = "device_lists"
     ROW_TYPE = DeviceListsStreamRow
 
@@ -594,13 +598,13 @@ class DeviceListsStream(_StreamFromIdGen):
             upper_limit_token = min(upper_limit_token, signatures_to_token)
 
         device_updates = [
-            (stream_id, (entity, False))
-            for stream_id, (entity,) in device_updates
+            (stream_id, (entity, False, hosts))
+            for stream_id, (entity, hosts) in device_updates
             if stream_id <= upper_limit_token
         ]
 
         signatures_updates = [
-            (stream_id, (entity, True))
+            (stream_id, (entity, True, False))
             for stream_id, (entity,) in signatures_updates
             if stream_id <= upper_limit_token
         ]
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 40187496e2..5eeca6165d 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -164,22 +164,24 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             prefilled_cache=user_signature_stream_prefill,
         )
 
-        (
-            device_list_federation_prefill,
-            device_list_federation_list_id,
-        ) = self.db_pool.get_cache_dict(
-            db_conn,
-            "device_lists_outbound_pokes",
-            entity_column="destination",
-            stream_column="stream_id",
-            max_value=device_list_max,
-            limit=10000,
-        )
-        self._device_list_federation_stream_cache = StreamChangeCache(
-            "DeviceListFederationStreamChangeCache",
-            device_list_federation_list_id,
-            prefilled_cache=device_list_federation_prefill,
-        )
+        self._device_list_federation_stream_cache = None
+        if hs.should_send_federation():
+            (
+                device_list_federation_prefill,
+                device_list_federation_list_id,
+            ) = self.db_pool.get_cache_dict(
+                db_conn,
+                "device_lists_outbound_pokes",
+                entity_column="destination",
+                stream_column="stream_id",
+                max_value=device_list_max,
+                limit=10000,
+            )
+            self._device_list_federation_stream_cache = StreamChangeCache(
+                "DeviceListFederationStreamChangeCache",
+                device_list_federation_list_id,
+                prefilled_cache=device_list_federation_prefill,
+            )
 
         if hs.config.worker.run_background_tasks:
             self._clock.looping_call(
@@ -207,23 +209,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     ) -> None:
         for row in rows:
             if row.is_signature:
-                self._user_signature_stream_cache.entity_has_changed(row.entity, token)
+                self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
                 continue
 
             # The entities are either user IDs (starting with '@') whose devices
             # have changed, or remote servers that we need to tell about
             # changes.
-            if row.entity.startswith("@"):
-                self._device_list_stream_cache.entity_has_changed(row.entity, token)
-                self.get_cached_devices_for_user.invalidate((row.entity,))
-                self._get_cached_user_device.invalidate((row.entity,))
-                self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
-
-            else:
-                self._device_list_federation_stream_cache.entity_has_changed(
-                    row.entity, token
+            if not row.hosts_calculated:
+                self._device_list_stream_cache.entity_has_changed(row.user_id, token)
+                self.get_cached_devices_for_user.invalidate((row.user_id,))
+                self._get_cached_user_device.invalidate((row.user_id,))
+                self.get_device_list_last_stream_id_for_remote.invalidate(
+                    (row.user_id,)
                 )
 
+    def device_lists_outbound_pokes_have_changed(
+        self, destinations: StrCollection, token: int
+    ) -> None:
+        assert self._device_list_federation_stream_cache is not None
+
+        for destination in destinations:
+            self._device_list_federation_stream_cache.entity_has_changed(
+                destination, token
+            )
+
     def device_lists_in_rooms_have_changed(
         self, room_ids: StrCollection, token: int
     ) -> None:
@@ -363,6 +372,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
               EDU contents.
         """
         now_stream_id = self.get_device_stream_token()
+        if from_stream_id == now_stream_id:
+            return now_stream_id, []
+
+        if self._device_list_federation_stream_cache is None:
+            raise Exception("Func can only be used on federation senders")
 
         has_changed = self._device_list_federation_stream_cache.has_entity_changed(
             destination, int(from_stream_id)
@@ -1018,10 +1032,10 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             # This query Does The Right Thing where it'll correctly apply the
             # bounds to the inner queries.
             sql = """
-                SELECT stream_id, entity FROM (
-                    SELECT stream_id, user_id AS entity FROM device_lists_stream
+                SELECT stream_id, user_id, hosts FROM (
+                    SELECT stream_id, user_id, false AS hosts FROM device_lists_stream
                     UNION ALL
-                    SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+                    SELECT DISTINCT stream_id, user_id, true AS hosts FROM device_lists_outbound_pokes
                 ) AS e
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
@@ -1577,6 +1591,14 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             get_device_list_changes_in_room_txn,
         )
 
+    async def get_destinations_for_device(self, stream_id: int) -> StrCollection:
+        return await self.db_pool.simple_select_onecol(
+            table="device_lists_outbound_pokes",
+            keyvalues={"stream_id": stream_id},
+            retcol="destination",
+            desc="get_destinations_for_device",
+        )
+
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
     def __init__(
@@ -2112,12 +2134,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         stream_ids: List[int],
         context: Optional[Dict[str, str]],
     ) -> None:
-        for host in hosts:
-            txn.call_after(
-                self._device_list_federation_stream_cache.entity_has_changed,
-                host,
-                stream_ids[-1],
-            )
+        if self._device_list_federation_stream_cache:
+            for host in hosts:
+                txn.call_after(
+                    self._device_list_federation_stream_cache.entity_has_changed,
+                    host,
+                    stream_ids[-1],
+                )
 
         now = self._clock.time_msec()
         stream_id_iterator = iter(stream_ids)
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 38d8785faa..9e6c9561ae 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -123,9 +123,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         if stream_name == DeviceListsStream.NAME:
             for row in rows:
                 assert isinstance(row, DeviceListsStream.DeviceListsStreamRow)
-                if row.entity.startswith("@"):
+                if not row.hosts_calculated:
                     self._get_e2e_device_keys_for_federation_query_inner.invalidate(
-                        (row.entity,)
+                        (row.user_id,)
                     )
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 7f975d04ff..ba01b038ab 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -36,6 +36,14 @@ class DeviceStoreTestCase(HomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
 
+    def default_config(self) -> JsonDict:
+        config = super().default_config()
+
+        # We 'enable' federation otherwise `get_device_updates_by_remote` will
+        # throw an exception.
+        config["federation_sender_instances"] = ["master"]
+        return config
+
     def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None:
         """Add a device list change for the given device to
         `device_lists_outbound_pokes` table.