summary refs log tree commit diff
path: root/synapse/handlers/device.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/device.py')
-rw-r--r--synapse/handlers/device.py84
1 files changed, 82 insertions, 2 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index f2ef591103..03082fce42 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -762,10 +762,90 @@ class DeviceHandler(DeviceWorkerHandler):
         gone from partial to full state.
         """
 
-        # We defer to the device list updater implementation as we're on the
-        # right worker.
+        # We defer to the device list updater to handle pending remote device
+        # list updates.
         await self.device_list_updater.handle_room_un_partial_stated(room_id)
 
+        # Replay local updates.
+        (
+            join_event_id,
+            device_lists_stream_id,
+        ) = await self.store.get_join_event_id_and_device_lists_stream_id_for_partial_state(
+            room_id
+        )
+
+        # Get the local device list changes that have happened in the room since
+        # we started joining. If there are no updates there's nothing left to do.
+        changes = await self.store.get_device_list_changes_in_room(
+            room_id, device_lists_stream_id
+        )
+        local_changes = {(u, d) for u, d in changes if self.hs.is_mine_id(u)}
+        if not local_changes:
+            return
+
+        # Note: We have persisted the full state at this point, we just haven't
+        # cleared the `partial_room` flag.
+        join_state_ids = await self._state_storage.get_state_ids_for_event(
+            join_event_id, await_full_state=False
+        )
+        current_state_ids = await self.store.get_partial_current_state_ids(room_id)
+
+        # Now we need to work out all servers that might have been in the room
+        # at any point during our join.
+
+        # First we look for any membership states that have changed between the
+        # initial join and now...
+        all_keys = set(join_state_ids)
+        all_keys.update(current_state_ids)
+
+        potentially_changed_hosts = set()
+        for etype, state_key in all_keys:
+            if etype != EventTypes.Member:
+                continue
+
+            prev = join_state_ids.get((etype, state_key))
+            current = current_state_ids.get((etype, state_key))
+
+            if prev != current:
+                potentially_changed_hosts.add(get_domain_from_id(state_key))
+
+        # ... then we add all the hosts that are currently joined to the room...
+        current_hosts_in_room = await self.store.get_current_hosts_in_room(room_id)
+        potentially_changed_hosts.update(current_hosts_in_room)
+
+        # ... and finally we remove any hosts that we were told about, as we
+        # will have sent device list updates to those hosts when they happened.
+        known_hosts_at_join = await self.store.get_partial_state_servers_at_join(
+            room_id
+        )
+        potentially_changed_hosts.difference_update(known_hosts_at_join)
+
+        potentially_changed_hosts.discard(self.server_name)
+
+        if not potentially_changed_hosts:
+            # Nothing to do.
+            return
+
+        logger.info(
+            "Found %d changed hosts to send device list updates to",
+            len(potentially_changed_hosts),
+        )
+
+        for user_id, device_id in local_changes:
+            await self.store.add_device_list_outbound_pokes(
+                user_id=user_id,
+                device_id=device_id,
+                room_id=room_id,
+                stream_id=None,
+                hosts=potentially_changed_hosts,
+                context=None,
+            )
+
+        # Notify things that device lists need to be sent out.
+        self.notifier.notify_replication()
+        for host in potentially_changed_hosts:
+            self.federation_sender.send_device_messages(host, immediate=False)
+
 
 def _update_device_from_client_ips(
     device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]