diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 3e2c3191c8..0df12e6380 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -35,6 +35,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.databases.main.devices import DeviceWorkerStore
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
@@ -536,6 +537,17 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
# account data entry to delete in the first place.
return False
+ # Record that this account data was deleted along with the devices that
+ # have yet to see it. Once all devices have later seen the delete, we can
+ # fully purge the row from `room_account_data`.
+ self._add_entries_to_account_data_undelivered_deletes_txn(
+ txn,
+ stream_id=next_id,
+ account_data_type=account_data_type,
+ room_id=room_id,
+ user_id=user_id,
+ )
+
return True
async with self._account_data_id_gen.get_next() as next_id:
@@ -706,6 +718,17 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
# account data entry to delete in the first place.
return False
+ # Record that this account data was deleted along with the devices that
+ # have yet to see it. Once all devices have later seen the delete, we can
+ # fully purge the row from `account_data`.
+ self._add_entries_to_account_data_undelivered_deletes_txn(
+ txn,
+ stream_id=next_id,
+ account_data_type=account_data_type,
+ room_id=None,
+ user_id=user_id,
+ )
+
# Ignored users get denormalized into a separate table as an optimisation.
if account_data_type == AccountDataTypes.IGNORED_USER_LIST:
# If this method was called with the ignored users account data type, we
@@ -757,6 +780,47 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
return self._account_data_id_gen.get_current_token()
+ def _add_entries_to_account_data_undelivered_deletes_txn(
+ self,
+ txn: LoggingTransaction,
+ stream_id: int,
+ account_data_type: str,
+ room_id: Optional[str],
+ user_id: str,
+ ) -> None:
+ """
+ Adds an entry per device of the given user to the
+ 'account_data_undelivered_deletes' table, which tracks the devices that have
+ yet to be informed of a deleted entry in either the 'account_data' or
+ 'room_account_data' tables.
+
+ Entries for hidden devices will not be created.
+
+ Args:
+ txn: The transaction that is handling the delete from (room)_account_data.
+ stream_id: The stream_id of the delete entry in the (room)_account_data table.
+ account_data_type: The type of {room,user} account data that was deleted.
+ room_id: The ID of the room if this refers to room account data, otherwise
+ this should be None.
+ user_id: The ID of the user this account data is related to.
+ """
+ # TODO: Is this too gross? :P
+ # Alternative is to duplicate the code in get_devices_by_user
+ user_device_dicts = DeviceWorkerStore.get_devices_by_user_txn(
+ txn, self.db_pool, user_id
+ )
+
+ # Create an entry in the deleted account data table for each device ID.
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="account_data_undelivered_deletes",
+ keys=("stream_id", "type", "room_id", "user_id", "device_id"),
+ values=(
+ (stream_id, account_data_type, room_id, user_id, device_id)
+ for device_id in user_device_dicts.keys()
+ ),
+ )
+
async def purge_account_data_for_user(self, user_id: str) -> None:
"""
Removes ALL the account data for a user.
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index a5bb4d404e..1a867d1f3e 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -276,11 +276,22 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
A mapping from device_id to a dict containing "device_id", "user_id"
and "display_name" for each device.
"""
- devices = await self.db_pool.simple_select_list(
+ return await self.db_pool.runInteraction(
+ "get_devices_by_user",
+ self.get_devices_by_user_txn,
+ self.db_pool,
+ user_id,
+ )
+
+ @staticmethod
+ def get_devices_by_user_txn(
+ txn: LoggingTransaction, db_pool: DatabasePool, user_id: str
+ ) -> Dict[str, Dict[str, str]]:
+ devices = db_pool.simple_select_list_txn(
+ txn,
table="devices",
keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
- desc="get_devices_by_user",
)
return {d["device_id"]: d for d in devices}
|