summary refs log tree commit diff
path: root/synapse/handlers/device.py
diff options
context:
space:
mode:
authorMathieu Velten <mathieuv@matrix.org>2023-10-09 11:20:08 +0200
committerMathieu Velten <mathieuv@matrix.org>2023-10-09 15:19:19 +0200
commit99fefd5501858ba90a10f35af757f424bc189f8c (patch)
treed86eb4f21a9d995bea57c79aabceeaf83f8c72ff /synapse/handlers/device.py
parentargggggghhhh (diff)
parentFix possible AttributeError when account-api is called over unix socket (#16404) (diff)
downloadsynapse-99fefd5501858ba90a10f35af757f424bc189f8c.tar.xz
Merge remote-tracking branch 'origin/develop' into anoa/public_rooms_module_api
Diffstat (limited to 'synapse/handlers/device.py')
-rw-r--r--synapse/handlers/device.py161
1 files changed, 133 insertions, 28 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 5d12a39e26..86ad96d030 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -43,9 +43,12 @@ from synapse.metrics.background_process_metrics import (
 )
 from synapse.types import (
     JsonDict,
+    JsonMapping,
+    ScheduledTask,
     StrCollection,
     StreamKeyType,
     StreamToken,
+    TaskStatus,
     UserID,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
@@ -55,13 +58,17 @@ from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.cancellation import cancellable
 from synapse.util.metrics import measure_func
-from synapse.util.retryutils import NotRetryingDestination
+from synapse.util.retryutils import (
+    NotRetryingDestination,
+    filter_destinations_by_retry_limiter,
+)
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
+DELETE_DEVICE_MSGS_TASK_NAME = "delete_device_messages"
 MAX_DEVICE_DISPLAY_NAME_LEN = 100
 DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000
 
@@ -78,14 +85,20 @@ class DeviceWorkerHandler:
         self._appservice_handler = hs.get_application_service_handler()
         self._state_storage = hs.get_storage_controllers().state
         self._auth_handler = hs.get_auth_handler()
+        self._event_sources = hs.get_event_sources()
         self.server_name = hs.hostname
         self._msc3852_enabled = hs.config.experimental.msc3852_enabled
         self._query_appservices_for_keys = (
             hs.config.experimental.msc3984_appservice_key_query
         )
+        self._task_scheduler = hs.get_task_scheduler()
 
         self.device_list_updater = DeviceListWorkerUpdater(hs)
 
+        self._task_scheduler.register_action(
+            self._delete_device_messages, DELETE_DEVICE_MSGS_TASK_NAME
+        )
+
     @trace
     async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
         """
@@ -375,6 +388,33 @@ class DeviceWorkerHandler:
             "Trying handling device list state for partial join: not supported on workers."
         )
 
+    DEVICE_MSGS_DELETE_BATCH_LIMIT = 1000
+    DEVICE_MSGS_DELETE_SLEEP_MS = 1000
+
+    async def _delete_device_messages(
+        self,
+        task: ScheduledTask,
+    ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+        """Scheduler task to delete device messages in batch of `DEVICE_MSGS_DELETE_BATCH_LIMIT`."""
+        assert task.params is not None
+        user_id = task.params["user_id"]
+        device_id = task.params["device_id"]
+        up_to_stream_id = task.params["up_to_stream_id"]
+
+        # Delete the messages in batches to avoid too much DB load.
+        while True:
+            res = await self.store.delete_messages_for_device(
+                user_id=user_id,
+                device_id=device_id,
+                up_to_stream_id=up_to_stream_id,
+                limit=DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT,
+            )
+
+            if res < DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT:
+                return TaskStatus.COMPLETE, None, None
+
+            await self.clock.sleep(DeviceHandler.DEVICE_MSGS_DELETE_SLEEP_MS / 1000.0)
+
 
 class DeviceHandler(DeviceWorkerHandler):
     device_list_updater: "DeviceListUpdater"
@@ -385,6 +425,7 @@ class DeviceHandler(DeviceWorkerHandler):
         self.federation_sender = hs.get_federation_sender()
         self._account_data_handler = hs.get_account_data_handler()
         self._storage_controllers = hs.get_storage_controllers()
+        self.db_pool = hs.get_datastores().main.db_pool
 
         self.device_list_updater = DeviceListUpdater(hs, self)
 
@@ -529,6 +570,7 @@ class DeviceHandler(DeviceWorkerHandler):
             user_id: The user to delete devices from.
             device_ids: The list of device IDs to delete
         """
+        to_device_stream_id = self._event_sources.get_current_token().to_device_key
 
         try:
             await self.store.delete_devices(user_id, device_ids)
@@ -558,6 +600,17 @@ class DeviceHandler(DeviceWorkerHandler):
                     f"org.matrix.msc3890.local_notification_settings.{device_id}",
                 )
 
+            # Delete device messages asynchronously and in batches using the task scheduler
+            await self._task_scheduler.schedule_task(
+                DELETE_DEVICE_MSGS_TASK_NAME,
+                resource_id=device_id,
+                params={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "up_to_stream_id": to_device_stream_id,
+                },
+            )
+
         # Pushers are deleted after `delete_access_tokens_for_user` is called so that
         # modules using `on_logged_out` hook can use them if needed.
         await self.hs.get_pusherpool().remove_pushers_by_devices(user_id, device_ids)
@@ -653,29 +706,38 @@ class DeviceHandler(DeviceWorkerHandler):
     async def store_dehydrated_device(
         self,
         user_id: str,
+        device_id: Optional[str],
         device_data: JsonDict,
         initial_device_display_name: Optional[str] = None,
+        keys_for_device: Optional[JsonDict] = None,
     ) -> str:
-        """Store a dehydrated device for a user.  If the user had a previous
-        dehydrated device, it is removed.
+        """Store a dehydrated device for a user, optionally storing the keys associated with
+        it as well.  If the user had a previous dehydrated device, it is removed.
 
         Args:
             user_id: the user that we are storing the device for
+            device_id: device id supplied by client
             device_data: the dehydrated device information
             initial_device_display_name: The display name to use for the device
+            keys_for_device: keys for the dehydrated device
         Returns:
             device id of the dehydrated device
         """
         device_id = await self.check_device_registered(
             user_id,
-            None,
+            device_id,
             initial_device_display_name,
         )
+
+        time_now = self.clock.time_msec()
+
         old_device_id = await self.store.store_dehydrated_device(
-            user_id, device_id, device_data
+            user_id, device_id, device_data, time_now, keys_for_device
         )
+
         if old_device_id is not None:
             await self.delete_devices(user_id, [old_device_id])
+
         return device_id
 
     async def rehydrate_device(
@@ -697,12 +759,13 @@ class DeviceHandler(DeviceWorkerHandler):
 
         # If the dehydrated device was successfully deleted (the device ID
         # matched the stored dehydrated device), then modify the access
-        # token to use the dehydrated device's ID and copy the old device
-        # display name to the dehydrated device, and destroy the old device
-        # ID
+        # token and refresh token to use the dehydrated device's ID and
+        # copy the old device display name to the dehydrated device,
+        # and destroy the old device ID
         old_device_id = await self.store.set_device_for_access_token(
             access_token, device_id
         )
+        await self.store.set_device_for_refresh_token(user_id, old_device_id, device_id)
         old_device = await self.store.get_device(user_id, old_device_id)
         if old_device is None:
             raise errors.NotFoundError()
@@ -720,6 +783,22 @@ class DeviceHandler(DeviceWorkerHandler):
 
         return {"success": True}
 
+    async def delete_dehydrated_device(self, user_id: str, device_id: str) -> None:
+        """
+        Delete a stored dehydrated device.
+
+        Args:
+            user_id: the user_id to delete the device from
+            device_id: id of the dehydrated device to delete
+        """
+        success = await self.store.remove_dehydrated_device(user_id, device_id)
+
+        if not success:
+            raise errors.NotFoundError()
+
+        await self.delete_devices(user_id, [device_id])
+        await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
+
     @wrap_as_background_process("_handle_new_device_update_async")
     async def _handle_new_device_update_async(self) -> None:
         """Called when we have a new local device list update that we need to
@@ -810,17 +889,16 @@ class DeviceHandler(DeviceWorkerHandler):
                             user_id,
                             hosts,
                         )
-                        for host in hosts:
-                            self.federation_sender.send_device_messages(
-                                host, immediate=False
-                            )
-                            # TODO: when called, this isn't in a logging context.
-                            # This leads to log spam, sentry event spam, and massive
-                            # memory usage.
-                            # See https://github.com/matrix-org/synapse/issues/12552.
-                            # log_kv(
-                            #     {"message": "sent device update to host", "host": host}
-                            # )
+                        await self.federation_sender.send_device_messages(
+                            hosts, immediate=False
+                        )
+                        # TODO: when called, this isn't in a logging context.
+                        # This leads to log spam, sentry event spam, and massive
+                        # memory usage.
+                        # See https://github.com/matrix-org/synapse/issues/12552.
+                        # log_kv(
+                        #     {"message": "sent device update to host", "host": host}
+                        # )
 
                     if current_stream_id != stream_id:
                         # Clear the set of hosts we've already sent to as we're
@@ -925,8 +1003,9 @@ class DeviceHandler(DeviceWorkerHandler):
 
         # Notify things that device lists need to be sent out.
         self.notifier.notify_replication()
-        for host in potentially_changed_hosts:
-            self.federation_sender.send_device_messages(host, immediate=False)
+        await self.federation_sender.send_device_messages(
+            potentially_changed_hosts, immediate=False
+        )
 
 
 def _update_device_from_client_ips(
@@ -956,7 +1035,7 @@ class DeviceListWorkerUpdater:
 
     async def multi_user_device_resync(
         self, user_ids: List[str], mark_failed_as_stale: bool = True
-    ) -> Dict[str, Optional[JsonDict]]:
+    ) -> Dict[str, Optional[JsonMapping]]:
         """
         Like `user_device_resync` but operates on multiple users **from the same origin**
         at once.
@@ -985,6 +1064,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
         self._notifier = hs.get_notifier()
 
         self._remote_edu_linearizer = Linearizer(name="remote_device_list")
+        self._resync_linearizer = Linearizer(name="remote_device_resync")
 
         # user_id -> list of updates waiting to be handled.
         self._pending_updates: Dict[
@@ -1124,7 +1204,14 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
                 )
 
             if resync:
-                await self.multi_user_device_resync([user_id])
+                # We mark as stale up front in case we get restarted.
+                await self.store.mark_remote_users_device_caches_as_stale([user_id])
+                run_as_background_process(
+                    "_maybe_retry_device_resync",
+                    self.multi_user_device_resync,
+                    [user_id],
+                    False,
+                )
             else:
                 # Simply update the single device, since we know that is the only
                 # change (because of the single prev_id matching the current cache)
@@ -1187,8 +1274,18 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
             self._resync_retry_in_progress = True
             # Get all of the users that need resyncing.
             need_resync = await self.store.get_user_ids_requiring_device_list_resync()
+
+            # Filter out users whose host is marked as "down" up front.
+            hosts = await filter_destinations_by_retry_limiter(
+                {get_domain_from_id(u) for u in need_resync}, self.clock, self.store
+            )
+            hosts = set(hosts)
+
             # Iterate over the set of user IDs.
             for user_id in need_resync:
+                if get_domain_from_id(user_id) not in hosts:
+                    continue
+
                 try:
                     # Try to resync the current user's devices list.
                     result = (await self.multi_user_device_resync([user_id], False))[
@@ -1220,7 +1317,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
 
     async def multi_user_device_resync(
         self, user_ids: List[str], mark_failed_as_stale: bool = True
-    ) -> Dict[str, Optional[JsonDict]]:
+    ) -> Dict[str, Optional[JsonMapping]]:
         """
         Like `user_device_resync` but operates on multiple users **from the same origin**
         at once.
@@ -1240,9 +1337,11 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
         failed = set()
         # TODO(Perf): Actually batch these up
         for user_id in user_ids:
-            user_result, user_failed = await self._user_device_resync_returning_failed(
-                user_id
-            )
+            async with self._resync_linearizer.queue(user_id):
+                (
+                    user_result,
+                    user_failed,
+                ) = await self._user_device_resync_returning_failed(user_id)
             result[user_id] = user_result
             if user_failed:
                 failed.add(user_id)
@@ -1254,7 +1353,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
 
     async def _user_device_resync_returning_failed(
         self, user_id: str
-    ) -> Tuple[Optional[JsonDict], bool]:
+    ) -> Tuple[Optional[JsonMapping], bool]:
         """Fetches all devices for a user and updates the device cache with them.
 
         Args:
@@ -1267,6 +1366,12 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
               e.g. due to a connection problem.
             - True iff the resync failed and the device list should be marked as stale.
         """
+        # Check that we haven't gone and fetched the devices since we last
+        # checked if we needed to resync these device lists.
+        if await self.store.get_users_whose_devices_are_cached([user_id]):
+            cached = await self.store.get_cached_devices_for_user(user_id)
+            return cached, False
+
         logger.debug("Attempting to resync the device list for %s", user_id)
         log_kv({"message": "Doing resync to update device list."})
         # Fetch all devices for the user.