diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 763f56dfc1..86ad96d030 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -43,9 +43,12 @@ from synapse.metrics.background_process_metrics import (
)
from synapse.types import (
JsonDict,
+ JsonMapping,
+ ScheduledTask,
StrCollection,
StreamKeyType,
StreamToken,
+ TaskStatus,
UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
@@ -55,13 +58,17 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.cancellation import cancellable
from synapse.util.metrics import measure_func
-from synapse.util.retryutils import NotRetryingDestination
+from synapse.util.retryutils import (
+ NotRetryingDestination,
+ filter_destinations_by_retry_limiter,
+)
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
+DELETE_DEVICE_MSGS_TASK_NAME = "delete_device_messages"
MAX_DEVICE_DISPLAY_NAME_LEN = 100
DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000
@@ -78,14 +85,20 @@ class DeviceWorkerHandler:
self._appservice_handler = hs.get_application_service_handler()
self._state_storage = hs.get_storage_controllers().state
self._auth_handler = hs.get_auth_handler()
+ self._event_sources = hs.get_event_sources()
self.server_name = hs.hostname
self._msc3852_enabled = hs.config.experimental.msc3852_enabled
self._query_appservices_for_keys = (
hs.config.experimental.msc3984_appservice_key_query
)
+ self._task_scheduler = hs.get_task_scheduler()
self.device_list_updater = DeviceListWorkerUpdater(hs)
+ self._task_scheduler.register_action(
+ self._delete_device_messages, DELETE_DEVICE_MSGS_TASK_NAME
+ )
+
@trace
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
"""
@@ -375,6 +388,33 @@ class DeviceWorkerHandler:
"Trying handling device list state for partial join: not supported on workers."
)
+ DEVICE_MSGS_DELETE_BATCH_LIMIT = 1000
+ DEVICE_MSGS_DELETE_SLEEP_MS = 1000
+
+ async def _delete_device_messages(
+ self,
+ task: ScheduledTask,
+ ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+ """Scheduler task to delete device messages in batch of `DEVICE_MSGS_DELETE_BATCH_LIMIT`."""
+ assert task.params is not None
+ user_id = task.params["user_id"]
+ device_id = task.params["device_id"]
+ up_to_stream_id = task.params["up_to_stream_id"]
+
+ # Delete the messages in batches to avoid too much DB load.
+ while True:
+ res = await self.store.delete_messages_for_device(
+ user_id=user_id,
+ device_id=device_id,
+ up_to_stream_id=up_to_stream_id,
+ limit=DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT,
+ )
+
+ if res < DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT:
+ return TaskStatus.COMPLETE, None, None
+
+ await self.clock.sleep(DeviceHandler.DEVICE_MSGS_DELETE_SLEEP_MS / 1000.0)
+
class DeviceHandler(DeviceWorkerHandler):
device_list_updater: "DeviceListUpdater"
@@ -530,6 +570,7 @@ class DeviceHandler(DeviceWorkerHandler):
user_id: The user to delete devices from.
device_ids: The list of device IDs to delete
"""
+ to_device_stream_id = self._event_sources.get_current_token().to_device_key
try:
await self.store.delete_devices(user_id, device_ids)
@@ -559,6 +600,17 @@ class DeviceHandler(DeviceWorkerHandler):
f"org.matrix.msc3890.local_notification_settings.{device_id}",
)
+ # Delete device messages asynchronously and in batches using the task scheduler
+ await self._task_scheduler.schedule_task(
+ DELETE_DEVICE_MSGS_TASK_NAME,
+ resource_id=device_id,
+ params={
+ "user_id": user_id,
+ "device_id": device_id,
+ "up_to_stream_id": to_device_stream_id,
+ },
+ )
+
# Pushers are deleted after `delete_access_tokens_for_user` is called so that
# modules using `on_logged_out` hook can use them if needed.
await self.hs.get_pusherpool().remove_pushers_by_devices(user_id, device_ids)
@@ -707,12 +759,13 @@ class DeviceHandler(DeviceWorkerHandler):
# If the dehydrated device was successfully deleted (the device ID
# matched the stored dehydrated device), then modify the access
- # token to use the dehydrated device's ID and copy the old device
- # display name to the dehydrated device, and destroy the old device
- # ID
+ # token and refresh token to use the dehydrated device's ID and
+ # copy the old device display name to the dehydrated device,
+ # and destroy the old device ID
old_device_id = await self.store.set_device_for_access_token(
access_token, device_id
)
+ await self.store.set_device_for_refresh_token(user_id, old_device_id, device_id)
old_device = await self.store.get_device(user_id, old_device_id)
if old_device is None:
raise errors.NotFoundError()
@@ -982,7 +1035,7 @@ class DeviceListWorkerUpdater:
async def multi_user_device_resync(
self, user_ids: List[str], mark_failed_as_stale: bool = True
- ) -> Dict[str, Optional[JsonDict]]:
+ ) -> Dict[str, Optional[JsonMapping]]:
"""
Like `user_device_resync` but operates on multiple users **from the same origin**
at once.
@@ -1011,6 +1064,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
self._notifier = hs.get_notifier()
self._remote_edu_linearizer = Linearizer(name="remote_device_list")
+ self._resync_linearizer = Linearizer(name="remote_device_resync")
# user_id -> list of updates waiting to be handled.
self._pending_updates: Dict[
@@ -1220,8 +1274,18 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
self._resync_retry_in_progress = True
# Get all of the users that need resyncing.
need_resync = await self.store.get_user_ids_requiring_device_list_resync()
+
+ # Filter out users whose host is marked as "down" up front.
+ hosts = await filter_destinations_by_retry_limiter(
+ {get_domain_from_id(u) for u in need_resync}, self.clock, self.store
+ )
+ hosts = set(hosts)
+
# Iterate over the set of user IDs.
for user_id in need_resync:
+ if get_domain_from_id(user_id) not in hosts:
+ continue
+
try:
# Try to resync the current user's devices list.
result = (await self.multi_user_device_resync([user_id], False))[
@@ -1253,7 +1317,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
async def multi_user_device_resync(
self, user_ids: List[str], mark_failed_as_stale: bool = True
- ) -> Dict[str, Optional[JsonDict]]:
+ ) -> Dict[str, Optional[JsonMapping]]:
"""
Like `user_device_resync` but operates on multiple users **from the same origin**
at once.
@@ -1273,9 +1337,11 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
failed = set()
# TODO(Perf): Actually batch these up
for user_id in user_ids:
- user_result, user_failed = await self._user_device_resync_returning_failed(
- user_id
- )
+ async with self._resync_linearizer.queue(user_id):
+ (
+ user_result,
+ user_failed,
+ ) = await self._user_device_resync_returning_failed(user_id)
result[user_id] = user_result
if user_failed:
failed.add(user_id)
@@ -1287,7 +1353,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
async def _user_device_resync_returning_failed(
self, user_id: str
- ) -> Tuple[Optional[JsonDict], bool]:
+ ) -> Tuple[Optional[JsonMapping], bool]:
"""Fetches all devices for a user and updates the device cache with them.
Args:
@@ -1300,6 +1366,12 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
e.g. due to a connection problem.
- True iff the resync failed and the device list should be marked as stale.
"""
+ # Check that we haven't gone and fetched the devices since we last
+ # checked if we needed to resync these device lists.
+ if await self.store.get_users_whose_devices_are_cached([user_id]):
+ cached = await self.store.get_cached_devices_for_user(user_id)
+ return cached, False
+
logger.debug("Attempting to resync the device list for %s", user_id)
log_kv({"message": "Doing resync to update device list."})
# Fetch all devices for the user.
|