diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index d5ccaa0c37..c710c02cf9 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -37,7 +37,10 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+ run_as_background_process,
+ wrap_as_background_process,
+)
from synapse.types import (
JsonDict,
StreamToken,
@@ -278,6 +281,22 @@ class DeviceHandler(DeviceWorkerHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
+ # Whether `_handle_new_device_update_async` is currently processing changes.
+ self._handle_new_device_update_is_processing = False
+
+ # If a new device update may have happened while the loop was
+ # processing.
+ self._handle_new_device_update_new_data = False
+
+ # On startup, check if there are any updates pending.
+ hs.get_reactor().callWhenRunning(self._handle_new_device_update_async)
+
+ # Used to decide whether we calculate outbound pokes up front or not. By
+ # default we do, to allow safely downgrading Synapse.
+ self.use_new_device_lists_changes_in_room = (
+ hs.config.server.use_new_device_lists_changes_in_room
+ )
+
def _check_device_name_length(self, name: Optional[str]) -> None:
"""
Checks whether a device name is longer than the maximum allowed length.
@@ -469,19 +488,26 @@ class DeviceHandler(DeviceWorkerHandler):
# No changes to notify about, so this is a no-op.
return
- users_who_share_room = await self.store.get_users_who_share_room_with_user(
- user_id
- )
+ room_ids = await self.store.get_rooms_for_user(user_id)
+
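+ # If we're using the new device list changes table, leave `hosts` as
+ # None here and let `_handle_new_device_update_async` work out which
+ # remote servers need to be poked.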
+ hosts: Optional[Set[str]] = None
+ if not self.use_new_device_lists_changes_in_room:
+ hosts = set()
- hosts: Set[str] = set()
- if self.hs.is_mine_id(user_id):
- hosts.update(get_domain_from_id(u) for u in users_who_share_room)
- hosts.discard(self.server_name)
+ if self.hs.is_mine_id(user_id):
+ for room_id in room_ids:
+ joined_users = await self.store.get_users_in_room(room_id)
+ hosts.update(get_domain_from_id(u) for u in joined_users)
- set_tag("target_hosts", hosts)
+ set_tag("target_hosts", hosts)
+
+ hosts.discard(self.server_name)
position = await self.store.add_device_change_to_streams(
- user_id, device_ids, list(hosts)
+ user_id,
+ device_ids,
+ hosts=hosts,
+ room_ids=room_ids,
)
if not position:
@@ -495,9 +521,12 @@ class DeviceHandler(DeviceWorkerHandler):
# specify the user ID too since the user should always get their own device list
# updates, even if they aren't in any rooms.
- users_to_notify = users_who_share_room.union({user_id})
+ self.notifier.on_new_event(
+ "device_list_key", position, users={user_id}, rooms=room_ids
+ )
- self.notifier.on_new_event("device_list_key", position, users=users_to_notify)
+ # We may need to do some processing asynchronously.
+ self._handle_new_device_update_async()
if hosts:
logger.info(
@@ -614,6 +643,85 @@ class DeviceHandler(DeviceWorkerHandler):
return {"success": True}
+ @wrap_as_background_process("_handle_new_device_update_async")
+ async def _handle_new_device_update_async(self) -> None:
+ """Called when we have a new local device list update that we need to
+ send out over federation.
+
+ This happens in the background so as not to block the original request
+ that generated the device update.
+ """
+ if self._handle_new_device_update_is_processing:
+ self._handle_new_device_update_new_data = True
+ return
+
+ self._handle_new_device_update_is_processing = True
+
+ # The stream ID we processed in the previous iteration (if any), and the
+ # set of hosts we've already poked about for this update. This is so that
+ # we don't poke the same remote server about the same update repeatedly.
+ current_stream_id = None
+ hosts_already_sent_to: Set[str] = set()
+
+ try:
+ while True:
+ self._handle_new_device_update_new_data = False
+ rows = await self.store.get_uncoverted_outbound_room_pokes()
+ if not rows:
+ # If the DB returned nothing then there is nothing left to
+ # do, *unless* a new device list update happened during the
+ # DB query.
+ if self._handle_new_device_update_new_data:
+ continue
+ else:
+ return
+
+ for user_id, device_id, room_id, stream_id, opentracing_context in rows:
+ joined_user_ids = await self.store.get_users_in_room(room_id)
+ hosts = {get_domain_from_id(u) for u in joined_user_ids}
+ hosts.discard(self.server_name)
+
+ # Check if we've already sent this update to some hosts
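+ # (the same stream ID can appear in multiple rows, one per room, and we
+ # don't want to poke a remote server once per shared room).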
+ if current_stream_id == stream_id:
+ hosts -= hosts_already_sent_to
+
+ await self.store.add_device_list_outbound_pokes(
+ user_id=user_id,
+ device_id=device_id,
+ room_id=room_id,
+ stream_id=stream_id,
+ hosts=hosts,
+ context=opentracing_context,
+ )
+
+ # Notify replication that we've updated the device list stream.
+ self.notifier.notify_replication()
+
+ if hosts:
+ logger.info(
+ "Sending device list update notif for %r to: %r",
+ user_id,
+ hosts,
+ )
+ for host in hosts:
+ self.federation_sender.send_device_messages(
+ host, immediate=False
+ )
+ log_kv(
+ {"message": "sent device update to host", "host": host}
+ )
+
+ if current_stream_id != stream_id:
+ # Clear the set of hosts we've already sent to as we're
+ # processing a new update.
+ hosts_already_sent_to.clear()
+
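+ # Record the hosts we've now poked for this stream ID so that later
+ # rows for the same update skip them.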
+ hosts_already_sent_to.update(hosts)
+ current_stream_id = stream_id
+
+ finally:
+ self._handle_new_device_update_is_processing = False
+
def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]