summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13934.misc1
-rw-r--r--synapse/handlers/device.py84
-rw-r--r--synapse/storage/databases/main/devices.py55
-rw-r--r--synapse/storage/databases/main/room.py16
4 files changed, 141 insertions, 15 deletions
diff --git a/changelog.d/13934.misc b/changelog.d/13934.misc
new file mode 100644
index 0000000000..6610a9f567
--- /dev/null
+++ b/changelog.d/13934.misc
@@ -0,0 +1 @@
+Correctly handle sending local device list updates to remote servers during a partial join.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index f2ef591103..03082fce42 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -762,10 +762,90 @@ class DeviceHandler(DeviceWorkerHandler):
         gone from partial to full state.
         """
 
-        # We defer to the device list updater implementation as we're on the
-        # right worker.
+        # We defer to the device list updater to handle pending remote device
+        # list updates.
         await self.device_list_updater.handle_room_un_partial_stated(room_id)
 
+        # Replay local updates.
+        (
+            join_event_id,
+            device_lists_stream_id,
+        ) = await self.store.get_join_event_id_and_device_lists_stream_id_for_partial_state(
+            room_id
+        )
+
+        # Get the local device list changes that have happened in the room since
+        # we started joining. If there are no updates there's nothing left to do.
+        changes = await self.store.get_device_list_changes_in_room(
+            room_id, device_lists_stream_id
+        )
+        local_changes = {(u, d) for u, d in changes if self.hs.is_mine_id(u)}
+        if not local_changes:
+            return
+
+        # Note: We have persisted the full state at this point, we just haven't
+        # cleared the `partial_room` flag.
+        join_state_ids = await self._state_storage.get_state_ids_for_event(
+            join_event_id, await_full_state=False
+        )
+        current_state_ids = await self.store.get_partial_current_state_ids(room_id)
+
+        # Now we need to work out all servers that might have been in the room
+        # at any point during our join.
+
+        # First we look for any membership states that have changed between the
+        # initial join and now...
+        all_keys = set(join_state_ids)
+        all_keys.update(current_state_ids)
+
+        potentially_changed_hosts = set()
+        for etype, state_key in all_keys:
+            if etype != EventTypes.Member:
+                continue
+
+            prev = join_state_ids.get((etype, state_key))
+            current = current_state_ids.get((etype, state_key))
+
+            if prev != current:
+                potentially_changed_hosts.add(get_domain_from_id(state_key))
+
+        # ... then we add all the hosts that are currently joined to the room...
+        current_hosts_in_room = await self.store.get_current_hosts_in_room(room_id)
+        potentially_changed_hosts.update(current_hosts_in_room)
+
+        # ... and finally we remove any hosts that we were told about, as we
+        # will have sent device list updates to those hosts when they happened.
+        known_hosts_at_join = await self.store.get_partial_state_servers_at_join(
+            room_id
+        )
+        potentially_changed_hosts.difference_update(known_hosts_at_join)
+
+        potentially_changed_hosts.discard(self.server_name)
+
+        if not potentially_changed_hosts:
+            # Nothing to do.
+            return
+
+        logger.info(
+            "Found %d changed hosts to send device list updates to",
+            len(potentially_changed_hosts),
+        )
+
+        for user_id, device_id in local_changes:
+            await self.store.add_device_list_outbound_pokes(
+                user_id=user_id,
+                device_id=device_id,
+                room_id=room_id,
+                stream_id=None,
+                hosts=potentially_changed_hosts,
+                context=None,
+            )
+
+        # Notify things that device lists need to be sent out.
+        self.notifier.notify_replication()
+        for host in potentially_changed_hosts:
+            self.federation_sender.send_device_messages(host, immediate=False)
+
 
 def _update_device_from_client_ips(
     device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 1e562d4a40..18358eca46 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1307,6 +1307,33 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 
         return changes
 
+    async def get_device_list_changes_in_room(
+        self, room_id: str, min_stream_id: int
+    ) -> Collection[Tuple[str, str]]:
+        """Get all device list changes that happened in the room since the given
+        stream ID.
+
+        Returns:
+            Collection of user ID/device ID tuples of all devices that have
+            changed
+        """
+
+        sql = """
+            SELECT DISTINCT user_id, device_id FROM device_lists_changes_in_room
+            WHERE room_id = ? AND stream_id > ?
+        """
+
+        def get_device_list_changes_in_room_txn(
+            txn: LoggingTransaction,
+        ) -> Collection[Tuple[str, str]]:
+            txn.execute(sql, (room_id, min_stream_id))
+            return cast(Collection[Tuple[str, str]], txn.fetchall())
+
+        return await self.db_pool.runInteraction(
+            "get_device_list_changes_in_room",
+            get_device_list_changes_in_room_txn,
+        )
+
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
     def __init__(
@@ -1946,14 +1973,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         user_id: str,
         device_id: str,
         room_id: str,
-        stream_id: int,
+        stream_id: Optional[int],
         hosts: Collection[str],
         context: Optional[Dict[str, str]],
     ) -> None:
         """Queue the device update to be sent to the given set of hosts,
         calculated from the room ID.
 
-        Marks the associated row in `device_lists_changes_in_room` as handled.
+        Marks the associated row in `device_lists_changes_in_room` as handled,
+        if `stream_id` is provided.
         """
 
         def add_device_list_outbound_pokes_txn(
@@ -1969,17 +1997,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                     context=context,
                 )
 
-            self.db_pool.simple_update_txn(
-                txn,
-                table="device_lists_changes_in_room",
-                keyvalues={
-                    "user_id": user_id,
-                    "device_id": device_id,
-                    "stream_id": stream_id,
-                    "room_id": room_id,
-                },
-                updatevalues={"converted_to_destinations": True},
-            )
+            if stream_id:
+                self.db_pool.simple_update_txn(
+                    txn,
+                    table="device_lists_changes_in_room",
+                    keyvalues={
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "stream_id": stream_id,
+                        "room_id": room_id,
+                    },
+                    updatevalues={"converted_to_destinations": True},
+                )
 
         if not hosts:
             # If there are no hosts then we don't try and generate stream IDs.
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 672c9a03fc..059eef5c22 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1256,6 +1256,22 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 
         return entry is not None
 
+    async def get_join_event_id_and_device_lists_stream_id_for_partial_state(
+        self, room_id: str
+    ) -> Tuple[str, int]:
+        """Get the event ID of the initial join that started the partial
+        join, and the device list stream ID at the point we started the partial
+        join.
+        """
+
+        result = await self.db_pool.simple_select_one(
+            table="partial_state_rooms",
+            keyvalues={"room_id": room_id},
+            retcols=("join_event_id", "device_lists_stream_id"),
+            desc="get_join_event_id_for_partial_state",
+        )
+        return result["join_event_id"], result["device_lists_stream_id"]
+
 
 class _BackgroundUpdates:
     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"