diff --git a/changelog.d/17333.misc b/changelog.d/17333.misc
new file mode 100644
index 0000000000..d3ef0b3777
--- /dev/null
+++ b/changelog.d/17333.misc
@@ -0,0 +1 @@
+Handle device lists notifications for large accounts more efficiently in worker mode.
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 2d6d49eed7..3dddbb70b4 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -114,13 +114,19 @@ class ReplicationDataHandler:
"""
all_room_ids: Set[str] = set()
if stream_name == DeviceListsStream.NAME:
- if any(row.entity.startswith("@") and not row.is_signature for row in rows):
+ if any(not row.is_signature and not row.hosts_calculated for row in rows):
prev_token = self.store.get_device_stream_token()
all_room_ids = await self.store.get_all_device_list_changes(
prev_token, token
)
self.store.device_lists_in_rooms_have_changed(all_room_ids, token)
+ # If we're sending federation, we need to update the device lists
+ # outbound pokes stream change cache with updated hosts.
+ if self.send_handler and any(row.hosts_calculated for row in rows):
+ hosts = await self.store.get_destinations_for_device(token)
+ self.store.device_lists_outbound_pokes_have_changed(hosts, token)
+
self.store.process_replication_rows(stream_name, instance_name, token, rows)
# NOTE: this must be called after process_replication_rows to ensure any
# cache invalidations are first handled before any stream ID advances.
@@ -433,12 +439,11 @@ class FederationSenderHandler:
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
- hosts = {
- row.entity
- for row in rows
- if not row.entity.startswith("@") and not row.is_signature
- }
- await self.federation_sender.send_device_messages(hosts, immediate=False)
+ if any(row.hosts_calculated for row in rows):
+ hosts = await self.store.get_destinations_for_device(token)
+ await self.federation_sender.send_device_messages(
+ hosts, immediate=False
+ )
elif stream_name == ToDeviceStream.NAME:
# The to_device stream includes stuff to be pushed to both local
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 661206c841..d021904de7 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -549,10 +549,14 @@ class DeviceListsStream(_StreamFromIdGen):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceListsStreamRow:
- entity: str
+ user_id: str
# Indicates that a user has signed their own device with their user-signing key
is_signature: bool
+ # Indicates if this is a notification that we've calculated the hosts we
+ # need to send the update to.
+ hosts_calculated: bool
+
NAME = "device_lists"
ROW_TYPE = DeviceListsStreamRow
@@ -594,13 +598,13 @@ class DeviceListsStream(_StreamFromIdGen):
upper_limit_token = min(upper_limit_token, signatures_to_token)
device_updates = [
- (stream_id, (entity, False))
- for stream_id, (entity,) in device_updates
+ (stream_id, (entity, False, hosts))
+ for stream_id, (entity, hosts) in device_updates
if stream_id <= upper_limit_token
]
signatures_updates = [
- (stream_id, (entity, True))
+ (stream_id, (entity, True, False))
for stream_id, (entity,) in signatures_updates
if stream_id <= upper_limit_token
]
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 40187496e2..5eeca6165d 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -164,22 +164,24 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
prefilled_cache=user_signature_stream_prefill,
)
- (
- device_list_federation_prefill,
- device_list_federation_list_id,
- ) = self.db_pool.get_cache_dict(
- db_conn,
- "device_lists_outbound_pokes",
- entity_column="destination",
- stream_column="stream_id",
- max_value=device_list_max,
- limit=10000,
- )
- self._device_list_federation_stream_cache = StreamChangeCache(
- "DeviceListFederationStreamChangeCache",
- device_list_federation_list_id,
- prefilled_cache=device_list_federation_prefill,
- )
+ self._device_list_federation_stream_cache = None
+ if hs.should_send_federation():
+ (
+ device_list_federation_prefill,
+ device_list_federation_list_id,
+ ) = self.db_pool.get_cache_dict(
+ db_conn,
+ "device_lists_outbound_pokes",
+ entity_column="destination",
+ stream_column="stream_id",
+ max_value=device_list_max,
+ limit=10000,
+ )
+ self._device_list_federation_stream_cache = StreamChangeCache(
+ "DeviceListFederationStreamChangeCache",
+ device_list_federation_list_id,
+ prefilled_cache=device_list_federation_prefill,
+ )
if hs.config.worker.run_background_tasks:
self._clock.looping_call(
@@ -207,23 +209,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
) -> None:
for row in rows:
if row.is_signature:
- self._user_signature_stream_cache.entity_has_changed(row.entity, token)
+ self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
continue
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
- if row.entity.startswith("@"):
- self._device_list_stream_cache.entity_has_changed(row.entity, token)
- self.get_cached_devices_for_user.invalidate((row.entity,))
- self._get_cached_user_device.invalidate((row.entity,))
- self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
-
- else:
- self._device_list_federation_stream_cache.entity_has_changed(
- row.entity, token
+ if not row.hosts_calculated:
+ self._device_list_stream_cache.entity_has_changed(row.user_id, token)
+ self.get_cached_devices_for_user.invalidate((row.user_id,))
+ self._get_cached_user_device.invalidate((row.user_id,))
+ self.get_device_list_last_stream_id_for_remote.invalidate(
+ (row.user_id,)
)
+ def device_lists_outbound_pokes_have_changed(
+ self, destinations: StrCollection, token: int
+ ) -> None:
+ assert self._device_list_federation_stream_cache is not None
+
+ for destination in destinations:
+ self._device_list_federation_stream_cache.entity_has_changed(
+ destination, token
+ )
+
def device_lists_in_rooms_have_changed(
self, room_ids: StrCollection, token: int
) -> None:
@@ -363,6 +372,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
EDU contents.
"""
now_stream_id = self.get_device_stream_token()
+ if from_stream_id == now_stream_id:
+ return now_stream_id, []
+
+ if self._device_list_federation_stream_cache is None:
+ raise Exception("Func can only be used on federation senders")
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
@@ -1018,10 +1032,10 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
# This query Does The Right Thing where it'll correctly apply the
# bounds to the inner queries.
sql = """
- SELECT stream_id, entity FROM (
- SELECT stream_id, user_id AS entity FROM device_lists_stream
+ SELECT stream_id, user_id, hosts FROM (
+ SELECT stream_id, user_id, false AS hosts FROM device_lists_stream
UNION ALL
- SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+ SELECT DISTINCT stream_id, user_id, true AS hosts FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
@@ -1577,6 +1591,14 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
get_device_list_changes_in_room_txn,
)
+ async def get_destinations_for_device(self, stream_id: int) -> StrCollection:
+ return await self.db_pool.simple_select_onecol(
+ table="device_lists_outbound_pokes",
+ keyvalues={"stream_id": stream_id},
+ retcol="destination",
+ desc="get_destinations_for_device",
+ )
+
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(
@@ -2112,12 +2134,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
stream_ids: List[int],
context: Optional[Dict[str, str]],
) -> None:
- for host in hosts:
- txn.call_after(
- self._device_list_federation_stream_cache.entity_has_changed,
- host,
- stream_ids[-1],
- )
+ if self._device_list_federation_stream_cache:
+ for host in hosts:
+ txn.call_after(
+ self._device_list_federation_stream_cache.entity_has_changed,
+ host,
+ stream_ids[-1],
+ )
now = self._clock.time_msec()
stream_id_iterator = iter(stream_ids)
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 38d8785faa..9e6c9561ae 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -123,9 +123,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
if stream_name == DeviceListsStream.NAME:
for row in rows:
assert isinstance(row, DeviceListsStream.DeviceListsStreamRow)
- if row.entity.startswith("@"):
+ if not row.hosts_calculated:
self._get_e2e_device_keys_for_federation_query_inner.invalidate(
- (row.entity,)
+ (row.user_id,)
)
super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 7f975d04ff..ba01b038ab 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -36,6 +36,14 @@ class DeviceStoreTestCase(HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+
+ # We 'enable' federation, otherwise `get_device_updates_by_remote` will
+ # throw an exception.
+ config["federation_sender_instances"] = ["master"]
+ return config
+
def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None:
"""Add a device list change for the given device to
`device_lists_outbound_pokes` table.