summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/17333.misc1
-rw-r--r--synapse/replication/tcp/client.py19
-rw-r--r--synapse/replication/tcp/streams/_base.py12
-rw-r--r--synapse/storage/databases/main/devices.py93
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py4
-rw-r--r--tests/storage/test_devices.py8
6 files changed, 48 insertions, 89 deletions
diff --git a/changelog.d/17333.misc b/changelog.d/17333.misc
deleted file mode 100644
index d3ef0b3777..0000000000
--- a/changelog.d/17333.misc
+++ /dev/null
@@ -1 +0,0 @@
-Handle device lists notifications for large accounts more efficiently in worker mode.
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 3dddbb70b4..2d6d49eed7 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -114,19 +114,13 @@ class ReplicationDataHandler:
         """
         all_room_ids: Set[str] = set()
         if stream_name == DeviceListsStream.NAME:
-            if any(not row.is_signature and not row.hosts_calculated for row in rows):
+            if any(row.entity.startswith("@") and not row.is_signature for row in rows):
                 prev_token = self.store.get_device_stream_token()
                 all_room_ids = await self.store.get_all_device_list_changes(
                     prev_token, token
                 )
                 self.store.device_lists_in_rooms_have_changed(all_room_ids, token)
 
-            # If we're sending federation we need to update the device lists
-            # outbound pokes stream change cache with updated hosts.
-            if self.send_handler and any(row.hosts_calculated for row in rows):
-                hosts = await self.store.get_destinations_for_device(token)
-                self.store.device_lists_outbound_pokes_have_changed(hosts, token)
-
         self.store.process_replication_rows(stream_name, instance_name, token, rows)
         # NOTE: this must be called after process_replication_rows to ensure any
         # cache invalidations are first handled before any stream ID advances.
@@ -439,11 +433,12 @@ class FederationSenderHandler:
             # The entities are either user IDs (starting with '@') whose devices
             # have changed, or remote servers that we need to tell about
             # changes.
-            if any(row.hosts_calculated for row in rows):
-                hosts = await self.store.get_destinations_for_device(token)
-                await self.federation_sender.send_device_messages(
-                    hosts, immediate=False
-                )
+            hosts = {
+                row.entity
+                for row in rows
+                if not row.entity.startswith("@") and not row.is_signature
+            }
+            await self.federation_sender.send_device_messages(hosts, immediate=False)
 
         elif stream_name == ToDeviceStream.NAME:
             # The to_device stream includes stuff to be pushed to both local
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index d021904de7..661206c841 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -549,14 +549,10 @@ class DeviceListsStream(_StreamFromIdGen):
 
     @attr.s(slots=True, frozen=True, auto_attribs=True)
     class DeviceListsStreamRow:
-        user_id: str
+        entity: str
         # Indicates that a user has signed their own device with their user-signing key
         is_signature: bool
 
-        # Indicates if this is a notification that we've calculated the hosts we
-        # need to send the update to.
-        hosts_calculated: bool
-
     NAME = "device_lists"
     ROW_TYPE = DeviceListsStreamRow
 
@@ -598,13 +594,13 @@ class DeviceListsStream(_StreamFromIdGen):
             upper_limit_token = min(upper_limit_token, signatures_to_token)
 
         device_updates = [
-            (stream_id, (entity, False, hosts))
-            for stream_id, (entity, hosts) in device_updates
+            (stream_id, (entity, False))
+            for stream_id, (entity,) in device_updates
             if stream_id <= upper_limit_token
         ]
 
         signatures_updates = [
-            (stream_id, (entity, True, False))
+            (stream_id, (entity, True))
             for stream_id, (entity,) in signatures_updates
             if stream_id <= upper_limit_token
         ]
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 5eeca6165d..40187496e2 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -164,24 +164,22 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             prefilled_cache=user_signature_stream_prefill,
         )
 
-        self._device_list_federation_stream_cache = None
-        if hs.should_send_federation():
-            (
-                device_list_federation_prefill,
-                device_list_federation_list_id,
-            ) = self.db_pool.get_cache_dict(
-                db_conn,
-                "device_lists_outbound_pokes",
-                entity_column="destination",
-                stream_column="stream_id",
-                max_value=device_list_max,
-                limit=10000,
-            )
-            self._device_list_federation_stream_cache = StreamChangeCache(
-                "DeviceListFederationStreamChangeCache",
-                device_list_federation_list_id,
-                prefilled_cache=device_list_federation_prefill,
-            )
+        (
+            device_list_federation_prefill,
+            device_list_federation_list_id,
+        ) = self.db_pool.get_cache_dict(
+            db_conn,
+            "device_lists_outbound_pokes",
+            entity_column="destination",
+            stream_column="stream_id",
+            max_value=device_list_max,
+            limit=10000,
+        )
+        self._device_list_federation_stream_cache = StreamChangeCache(
+            "DeviceListFederationStreamChangeCache",
+            device_list_federation_list_id,
+            prefilled_cache=device_list_federation_prefill,
+        )
 
         if hs.config.worker.run_background_tasks:
             self._clock.looping_call(
@@ -209,29 +207,22 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     ) -> None:
         for row in rows:
             if row.is_signature:
-                self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
+                self._user_signature_stream_cache.entity_has_changed(row.entity, token)
                 continue
 
             # The entities are either user IDs (starting with '@') whose devices
             # have changed, or remote servers that we need to tell about
             # changes.
-            if not row.hosts_calculated:
-                self._device_list_stream_cache.entity_has_changed(row.user_id, token)
-                self.get_cached_devices_for_user.invalidate((row.user_id,))
-                self._get_cached_user_device.invalidate((row.user_id,))
-                self.get_device_list_last_stream_id_for_remote.invalidate(
-                    (row.user_id,)
-                )
+            if row.entity.startswith("@"):
+                self._device_list_stream_cache.entity_has_changed(row.entity, token)
+                self.get_cached_devices_for_user.invalidate((row.entity,))
+                self._get_cached_user_device.invalidate((row.entity,))
+                self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
 
-    def device_lists_outbound_pokes_have_changed(
-        self, destinations: StrCollection, token: int
-    ) -> None:
-        assert self._device_list_federation_stream_cache is not None
-
-        for destination in destinations:
-            self._device_list_federation_stream_cache.entity_has_changed(
-                destination, token
-            )
+            else:
+                self._device_list_federation_stream_cache.entity_has_changed(
+                    row.entity, token
+                )
 
     def device_lists_in_rooms_have_changed(
         self, room_ids: StrCollection, token: int
@@ -372,11 +363,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
               EDU contents.
         """
         now_stream_id = self.get_device_stream_token()
-        if from_stream_id == now_stream_id:
-            return now_stream_id, []
-
-        if self._device_list_federation_stream_cache is None:
-            raise Exception("Func can only be used on federation senders")
 
         has_changed = self._device_list_federation_stream_cache.has_entity_changed(
             destination, int(from_stream_id)
@@ -1032,10 +1018,10 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             # This query Does The Right Thing where it'll correctly apply the
             # bounds to the inner queries.
             sql = """
-                SELECT stream_id, user_id, hosts FROM (
-                    SELECT stream_id, user_id, false AS hosts FROM device_lists_stream
+                SELECT stream_id, entity FROM (
+                    SELECT stream_id, user_id AS entity FROM device_lists_stream
                     UNION ALL
-                    SELECT DISTINCT stream_id, user_id, true AS hosts FROM device_lists_outbound_pokes
+                    SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
                 ) AS e
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
@@ -1591,14 +1577,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             get_device_list_changes_in_room_txn,
         )
 
-    async def get_destinations_for_device(self, stream_id: int) -> StrCollection:
-        return await self.db_pool.simple_select_onecol(
-            table="device_lists_outbound_pokes",
-            keyvalues={"stream_id": stream_id},
-            retcol="destination",
-            desc="get_destinations_for_device",
-        )
-
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
     def __init__(
@@ -2134,13 +2112,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         stream_ids: List[int],
         context: Optional[Dict[str, str]],
     ) -> None:
-        if self._device_list_federation_stream_cache:
-            for host in hosts:
-                txn.call_after(
-                    self._device_list_federation_stream_cache.entity_has_changed,
-                    host,
-                    stream_ids[-1],
-                )
+        for host in hosts:
+            txn.call_after(
+                self._device_list_federation_stream_cache.entity_has_changed,
+                host,
+                stream_ids[-1],
+            )
 
         now = self._clock.time_msec()
         stream_id_iterator = iter(stream_ids)
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 9e6c9561ae..38d8785faa 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -123,9 +123,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         if stream_name == DeviceListsStream.NAME:
             for row in rows:
                 assert isinstance(row, DeviceListsStream.DeviceListsStreamRow)
-                if not row.hosts_calculated:
+                if row.entity.startswith("@"):
                     self._get_e2e_device_keys_for_federation_query_inner.invalidate(
-                        (row.user_id,)
+                        (row.entity,)
                     )
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index ba01b038ab..7f975d04ff 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -36,14 +36,6 @@ class DeviceStoreTestCase(HomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
 
-    def default_config(self) -> JsonDict:
-        config = super().default_config()
-
-        # We 'enable' federation otherwise `get_device_updates_by_remote` will
-        # throw an exception.
-        config["federation_sender_instances"] = ["master"]
-        return config
-
     def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None:
         """Add a device list change for the given device to
         `device_lists_outbound_pokes` table.