author     Kegan Dougal <7190048+kegsay@users.noreply.github.com>  2024-01-30 16:42:11 +0000
committer  Kegan Dougal <7190048+kegsay@users.noreply.github.com>  2024-01-30 16:42:11 +0000
commit     38b304eaa73d6386845b80411c562d26e4799061 (patch)
tree       5ca7e57e4859fe5f98920cc405ac8e49e93ef5b3
parent     Bump gitpython from 3.1.40 to 3.1.41 (#16850) (diff)
download   synapse-38b304eaa73d6386845b80411c562d26e4799061.tar.xz
Send device list update EDUs on room join
Fixes https://github.com/element-hq/synapse/issues/11374
Tested in https://github.com/matrix-org/complement/pull/704
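
The EDUs referred to in the title are m.device_list_update EDUs from the Matrix federation
spec. A rough sketch of one, with purely illustrative values (user, device and stream IDs
are made up, not taken from this change):

    {
        "edu_type": "m.device_list_update",
        "content": {
            "user_id": "@alice:example.org",  # the local user who just joined the room
            "device_id": "ABCDEFGHIJ",        # one of their devices
            "stream_id": 6,                   # this server's position in its device list stream
            "prev_id": [5],                   # previously sent stream_ids, if any
            "deleted": False,                 # True if the device was removed
            # "device_display_name" and "keys" may also be present, per the spec
        },
    }

Roughly speaking, each row queued by add_device_list_outbound_pokes in the diff below is
later turned into one of these EDUs by the federation sender, once per remote host in the
room.
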
-rw-r--r--  synapse/handlers/device.py | 146
1 file changed, 144 insertions, 2 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 9062fac91a..d5a8bc0126 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -23,7 +23,7 @@ import logging
 from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Set, Tuple
 
 from synapse.api import errors
-from synapse.api.constants import EduTypes, EventTypes
+from synapse.api.constants import EduTypes, EventTypes, Membership
 from synapse.api.errors import (
     Codes,
     FederationDeniedError,
@@ -33,11 +33,13 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.logging.opentracing import log_kv, set_tag, trace
+import synapse.metrics
 from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
 )
 from synapse.storage.databases.main.client_ips import DeviceLastConnectionInfo
+from synapse.storage.databases.main.state_deltas import StateDelta
 from synapse.types import (
     JsonDict,
     JsonMapping,
@@ -54,7 +56,7 @@ from synapse.util import stringutils
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.cancellation import cancellable
-from synapse.util.metrics import measure_func
+from synapse.util.metrics import Measure, measure_func
 from synapse.util.retryutils import (
     NotRetryingDestination,
     filter_destinations_by_retry_limiter,
@@ -428,6 +430,7 @@ class DeviceHandler(DeviceWorkerHandler):
         self._account_data_handler = hs.get_account_data_handler()
         self._storage_controllers = hs.get_storage_controllers()
         self.db_pool = hs.get_datastores().main.db_pool
+        self._event_processing = False
 
         self.device_list_updater = DeviceListUpdater(hs, self)
 
@@ -461,6 +464,145 @@ class DeviceHandler(DeviceWorkerHandler):
                 self._delete_stale_devices,
             )
 
+        # Listen for state delta updates, so that we can send device list updates
+        # to remote servers when a local user joins a room. We deliberately do not
+        # remember where we got up to: these updates only need to be sent on a
+        # best-effort basis, as they quickly heal via /keys/query requests, but
+        # sending them eagerly improves robustness on unreliable networks.
+        # See https://github.com/element-hq/synapse/issues/11374#issuecomment-1908396300
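+        # (Roughly: a remote server that misses one of these EDUs will still pick
+        # up the user's current devices the next time it fetches their keys over
+        # federation, e.g. when one of its own users makes a /keys/query request.)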
+        self._event_pos = self.store.get_room_max_stream_ordering()
+        self.notifier.add_replication_callback(self.notify_new_event)
+
+    def notify_new_event(self) -> None:
+        """Called when there may be more deltas to process"""
+        if self._event_processing:
+            return
+
+        self._event_processing = True
+
+        async def process() -> None:
+            try:
+                await self._unsafe_process()
+            finally:
+                self._event_processing = False
+
+        run_as_background_process("device.notify_new_event", process)
+
+    async def _unsafe_process(self) -> None:
+        # Loop round handling deltas until we're up to date
+        while True:
+            with Measure(self.clock, "device_list_delta"):
+                room_max_stream_ordering = self.store.get_room_max_stream_ordering()
+                if self._event_pos == room_max_stream_ordering:
+                    return
+
+                logger.debug(
+                    "Processing device list stats %s->%s",
+                    self._event_pos,
+                    room_max_stream_ordering,
+                )
+                (
+                    max_pos,
+                    deltas,
+                ) = await self._storage_controllers.state.get_current_state_deltas(
+                    self._event_pos, room_max_stream_ordering
+                )
+
+                # We may get multiple deltas for different rooms, but we want to
+                # handle them on a room-by-room basis, so we batch them up by
+                # room.
+                deltas_by_room: Dict[str, List[StateDelta]] = {}
+                for delta in deltas:
+                    deltas_by_room.setdefault(delta.room_id, []).append(delta)
+
+                for room_id, deltas_for_room in deltas_by_room.items():
+                    newly_joined_local_users = await self._get_newly_joined_local_users(
+                        room_id, deltas_for_room
+                    )
+                    if not newly_joined_local_users:
+                        continue
+                    # if a local user newly joins a room, we want to broadcast their device lists to
+                    # federated servers in that room, if we haven't already.
+                    hosts = await self.store.get_current_hosts_in_room(room_id)
+                    # filter out ourselves
+                    hosts = [h for h in hosts if not self.hs.is_mine_server_name(h)]
+                    if len(hosts) == 0:
+                        continue
+                    # broadcast device lists for these users in the room
+                    num_pokes = 0
+                    for user_id in newly_joined_local_users:
+                        # the join is for a local user, so we need to send device
+                        # list updates for all of their devices.
+                        device_ids = await self.store.get_devices_by_user(user_id)
+                        for device_id in device_ids.keys():
+                            num_pokes += 1
+                            await self.store.add_device_list_outbound_pokes(
+                                user_id=user_id,
+                                device_id=device_id,
+                                room_id=room_id,
+                                hosts=hosts,
+                                context=None,
+                            )
+                    logger.info(
+                        "Found %d hosts to send device list updates to for a new room join, "
+                        "added %d device_list_outbound_pokes",
+                        len(hosts),
+                        num_pokes,
+                    )
+
+                    # Notify things that device lists need to be sent out.
+                    self.notifier.notify_replication()
+                    await self.federation_sender.send_device_messages(
+                        hosts, immediate=False
+                    )
+
+                self._event_pos = max_pos
+
+                # Expose current event processing position to prometheus
+                synapse.metrics.event_processing_positions.labels("device").set(
+                    max_pos
+                )
+
+    async def _get_newly_joined_local_users(
+        self, room_id: str, deltas: List[StateDelta]
+    ) -> Set[str]:
+        """Process current state deltas for the room to find local users who have
+        newly joined and hence need their device lists broadcast.
+        """
+        newly_joined_local_users: Set[str] = set()
+
+        for delta in deltas:
+            assert room_id == delta.room_id
+            logger.debug(
+                "device.handling: %r %r, %s", delta.event_type, delta.state_key, delta.event_id
+            )
+            # Drop any event that isn't a membership event
+            if delta.event_type != EventTypes.Member:
+                continue
+            if delta.event_id is None:
+                # state has been deleted, so this is not a join. We only care about joins.
+                continue
+            # Drop any event that is for a non-local user
+            membership_change_user = UserID.from_string(delta.state_key)
+            if not self.hs.is_mine(membership_change_user):
+                continue
+            event = await self.store.get_event(delta.event_id, allow_none=True)
+            if not event or event.content.get("membership") != Membership.JOIN:
+                # We only care about joins
+                continue
+            if delta.prev_event_id:
+                prev_event = await self.store.get_event(
+                    delta.prev_event_id, allow_none=True
+                )
+                if (
+                    prev_event
+                    and prev_event.content.get("membership") == Membership.JOIN
+                ):
+                    # Ignore changes to join events.
+                    continue
+            newly_joined_local_users.add(delta.state_key)
+
+        return newly_joined_local_users
+
     def _check_device_name_length(self, name: Optional[str]) -> None:
         """
         Checks whether a device name is longer than the maximum allowed length.