summary refs log tree commit diff
path: root/synapse/handlers/device.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/device.py')
-rw-r--r--synapse/handlers/device.py134
1 files changed, 121 insertions, 13 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index d5ccaa0c37..ffa28b2a30 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -37,7 +37,10 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
+)
 from synapse.types import (
     JsonDict,
     StreamToken,
@@ -278,6 +281,22 @@ class DeviceHandler(DeviceWorkerHandler):
 
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
+        # Whether `_handle_new_device_update_async` is currently processing.
+        self._handle_new_device_update_is_processing = False
+
+        # If a new device update may have happened while the loop was
+        # processing.
+        self._handle_new_device_update_new_data = False
+
+        # On start up check if there are any updates pending.
+        hs.get_reactor().callWhenRunning(self._handle_new_device_update_async)
+
+        # Used to decide if we calculate outbound pokes up front or not. By
+        # default we do to allow safely downgrading Synapse.
+        self.use_new_device_lists_changes_in_room = (
+            hs.config.server.use_new_device_lists_changes_in_room
+        )
+
     def _check_device_name_length(self, name: Optional[str]) -> None:
         """
         Checks whether a device name is longer than the maximum allowed length.
@@ -469,19 +488,26 @@ class DeviceHandler(DeviceWorkerHandler):
             # No changes to notify about, so this is a no-op.
             return
 
-        users_who_share_room = await self.store.get_users_who_share_room_with_user(
-            user_id
-        )
+        room_ids = await self.store.get_rooms_for_user(user_id)
+
+        hosts: Optional[Set[str]] = None
+        if not self.use_new_device_lists_changes_in_room:
+            hosts = set()
 
-        hosts: Set[str] = set()
-        if self.hs.is_mine_id(user_id):
-            hosts.update(get_domain_from_id(u) for u in users_who_share_room)
-            hosts.discard(self.server_name)
+            if self.hs.is_mine_id(user_id):
+                for room_id in room_ids:
+                    joined_users = await self.store.get_users_in_room(room_id)
+                    hosts.update(get_domain_from_id(u) for u in joined_users)
 
-        set_tag("target_hosts", hosts)
+                set_tag("target_hosts", hosts)
+
+                hosts.discard(self.server_name)
 
         position = await self.store.add_device_change_to_streams(
-            user_id, device_ids, list(hosts)
+            user_id,
+            device_ids,
+            hosts=hosts,
+            room_ids=room_ids,
         )
 
         if not position:
@@ -495,9 +521,12 @@ class DeviceHandler(DeviceWorkerHandler):
 
         # specify the user ID too since the user should always get their own device list
         # updates, even if they aren't in any rooms.
-        users_to_notify = users_who_share_room.union({user_id})
+        self.notifier.on_new_event(
+            "device_list_key", position, users={user_id}, rooms=room_ids
+        )
 
-        self.notifier.on_new_event("device_list_key", position, users=users_to_notify)
+        # We may need to do some processing asynchronously.
+        self._handle_new_device_update_async()
 
         if hosts:
             logger.info(
@@ -614,6 +643,85 @@ class DeviceHandler(DeviceWorkerHandler):
 
         return {"success": True}
 
+    @wrap_as_background_process("_handle_new_device_update_async")
+    async def _handle_new_device_update_async(self) -> None:
+        """Called when we have a new local device list update that we need to
+        send out over federation.
+
+        This happens in the background so as not to block the original request
+        that generated the device update.
+        """
+        if self._handle_new_device_update_is_processing:
+            self._handle_new_device_update_new_data = True
+            return
+
+        self._handle_new_device_update_is_processing = True
+
+        # The stream ID we processed previous iteration (if any), and the set of
+        # hosts we've already poked about for this update. This is so that we
+        # don't poke the same remote server about the same update repeatedly.
+        current_stream_id = None
+        hosts_already_sent_to: Set[str] = set()
+
+        try:
+            while True:
+                self._handle_new_device_update_new_data = False
+                rows = await self.store.get_uncoverted_outbound_room_pokes()
+                if not rows:
+                    # If the DB returned nothing then there is nothing left to
+                    # do, *unless* a new device list update happened during the
+                    # DB query.
+                    if self._handle_new_device_update_new_data:
+                        continue
+                    else:
+                        return
+
+                for user_id, device_id, room_id, stream_id, opentracing_context in rows:
+                    joined_user_ids = await self.store.get_users_in_room(room_id)
+                    hosts = {get_domain_from_id(u) for u in joined_user_ids}
+                    hosts.discard(self.server_name)
+
+                    # Check if we've already sent this update to some hosts
+                    if current_stream_id == stream_id:
+                        hosts -= hosts_already_sent_to
+
+                    await self.store.add_device_list_outbound_pokes(
+                        user_id=user_id,
+                        device_id=device_id,
+                        room_id=room_id,
+                        stream_id=stream_id,
+                        hosts=hosts,
+                        context=opentracing_context,
+                    )
+
+                    # Notify replication that we've updated the device list stream.
+                    self.notifier.notify_replication()
+
+                    if hosts:
+                        logger.info(
+                            "Sending device list update notif for %r to: %r",
+                            user_id,
+                            hosts,
+                        )
+                        for host in hosts:
+                            self.federation_sender.send_device_messages(
+                                host, immediate=False
+                            )
+                            log_kv(
+                                {"message": "sent device update to host", "host": host}
+                            )
+
+                    if current_stream_id != stream_id:
+                        # Clear the set of hosts we've already sent to as we're
+                        # processing a new update.
+                        hosts_already_sent_to.clear()
+
+                    hosts_already_sent_to.update(hosts)
+                    current_stream_id = stream_id
+
+        finally:
+            self._handle_new_device_update_is_processing = False
+
 
 def _update_device_from_client_ips(
     device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
@@ -725,7 +833,7 @@ class DeviceListUpdater:
     async def _handle_device_updates(self, user_id: str) -> None:
         "Actually handle pending updates."
 
-        with (await self._remote_edu_linearizer.queue(user_id)):
+        async with self._remote_edu_linearizer.queue(user_id):
             pending_updates = self._pending_updates.pop(user_id, [])
             if not pending_updates:
                 # This can happen since we batch updates