summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/presence.py43
-rw-r--r--synapse/config/_base.py7
-rw-r--r--synapse/handlers/device.py48
-rw-r--r--synapse/handlers/initial_sync.py8
-rw-r--r--synapse/handlers/pagination.py12
-rw-r--r--synapse/handlers/presence.py300
-rw-r--r--synapse/handlers/room.py10
-rw-r--r--synapse/handlers/sync.py16
-rw-r--r--synapse/http/federation/matrix_federation_agent.py29
-rw-r--r--synapse/logging/context.py4
-rw-r--r--synapse/module_api/__init__.py13
-rw-r--r--synapse/module_api/callbacks/third_party_event_rules_callbacks.py11
-rw-r--r--synapse/storage/databases/main/deviceinbox.py28
-rw-r--r--synapse/storage/databases/main/devices.py8
-rw-r--r--synapse/storage/databases/main/keys.py17
-rw-r--r--synapse/storage/databases/main/receipts.py6
-rw-r--r--synapse/storage/engines/_base.py6
-rw-r--r--synapse/storage/engines/postgres.py4
-rw-r--r--synapse/storage/engines/sqlite.py4
-rw-r--r--synapse/storage/schema/main/delta/48/group_unique_indexes.py4
-rw-r--r--synapse/util/gai_resolver.py2
-rw-r--r--synapse/util/task_scheduler.py17
22 files changed, 479 insertions, 118 deletions
diff --git a/synapse/api/presence.py b/synapse/api/presence.py
index b80aa83cb3..b78f419994 100644
--- a/synapse/api/presence.py
+++ b/synapse/api/presence.py
@@ -20,18 +20,53 @@ from synapse.api.constants import PresenceState
 from synapse.types import JsonDict
 
 
+@attr.s(slots=True, auto_attribs=True)
+class UserDevicePresenceState:
+    """
+    Represents the current presence state of a user's device.
+
+    user_id: The user ID.
+    device_id: The user's device ID.
+    state: The presence state, see PresenceState.
+    last_active_ts: Time in msec that the device last interacted with server.
+    last_sync_ts: Time in msec that the device last *completed* a sync
+        (or event stream).
+    """
+
+    user_id: str
+    device_id: Optional[str]
+    state: str
+    last_active_ts: int
+    last_sync_ts: int
+
+    @classmethod
+    def default(
+        cls, user_id: str, device_id: Optional[str]
+    ) -> "UserDevicePresenceState":
+        """Returns a default presence state."""
+        return cls(
+            user_id=user_id,
+            device_id=device_id,
+            state=PresenceState.OFFLINE,
+            last_active_ts=0,
+            last_sync_ts=0,
+        )
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class UserPresenceState:
     """Represents the current presence state of the user.
 
-    user_id
-    last_active: Time in msec that the user last interacted with server.
-    last_federation_update: Time in msec since either a) we sent a presence
+    user_id: The user ID.
+    state: The presence state, see PresenceState.
+    last_active_ts: Time in msec that the user last interacted with server.
+    last_federation_update_ts: Time in msec since either a) we sent a presence
         update to other servers or b) we received a presence update, depending
         on if is a local user or not.
-    last_user_sync: Time in msec that the user last *completed* a sync
+    last_user_sync_ts: Time in msec that the user last *completed* a sync
         (or event stream).
     status_msg: User set status message.
+    currently_active: True if the user is currently syncing.
     """
 
     user_id: str
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 69a8318127..58856839e1 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -179,8 +179,9 @@ class Config:
 
         If an integer is provided it is treated as bytes and is unchanged.
 
-        String byte sizes can have a suffix of 'K' or `M`, representing kibibytes and
-        mebibytes respectively. No suffix is understood as a plain byte count.
+        String byte sizes can have a suffix of 'K', `M`, `G` or `T`,
+        representing kibibytes, mebibytes, gibibytes and tebibytes respectively.
+        No suffix is understood as a plain byte count.
 
         Raises:
             TypeError, if given something other than an integer or a string
@@ -189,7 +190,7 @@ class Config:
         if type(value) is int:  # noqa: E721
             return value
         elif isinstance(value, str):
-            sizes = {"K": 1024, "M": 1024 * 1024}
+            sizes = {"K": 1024, "M": 1024 * 1024, "G": 1024**3, "T": 1024**4}
             size = 1
             suffix = value[-1]
             if suffix in sizes:
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 763f56dfc1..9e52af5f13 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -43,9 +43,12 @@ from synapse.metrics.background_process_metrics import (
 )
 from synapse.types import (
     JsonDict,
+    JsonMapping,
+    ScheduledTask,
     StrCollection,
     StreamKeyType,
     StreamToken,
+    TaskStatus,
     UserID,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
@@ -62,6 +65,7 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+DELETE_DEVICE_MSGS_TASK_NAME = "delete_device_messages"
 MAX_DEVICE_DISPLAY_NAME_LEN = 100
 DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000
 
@@ -78,6 +82,7 @@ class DeviceWorkerHandler:
         self._appservice_handler = hs.get_application_service_handler()
         self._state_storage = hs.get_storage_controllers().state
         self._auth_handler = hs.get_auth_handler()
+        self._event_sources = hs.get_event_sources()
         self.server_name = hs.hostname
         self._msc3852_enabled = hs.config.experimental.msc3852_enabled
         self._query_appservices_for_keys = (
@@ -386,6 +391,7 @@ class DeviceHandler(DeviceWorkerHandler):
         self._account_data_handler = hs.get_account_data_handler()
         self._storage_controllers = hs.get_storage_controllers()
         self.db_pool = hs.get_datastores().main.db_pool
+        self._task_scheduler = hs.get_task_scheduler()
 
         self.device_list_updater = DeviceListUpdater(hs, self)
 
@@ -419,6 +425,10 @@ class DeviceHandler(DeviceWorkerHandler):
                 self._delete_stale_devices,
             )
 
+        self._task_scheduler.register_action(
+            self._delete_device_messages, DELETE_DEVICE_MSGS_TASK_NAME
+        )
+
     def _check_device_name_length(self, name: Optional[str]) -> None:
         """
         Checks whether a device name is longer than the maximum allowed length.
@@ -530,6 +540,7 @@ class DeviceHandler(DeviceWorkerHandler):
             user_id: The user to delete devices from.
             device_ids: The list of device IDs to delete
         """
+        to_device_stream_id = self._event_sources.get_current_token().to_device_key
 
         try:
             await self.store.delete_devices(user_id, device_ids)
@@ -559,12 +570,49 @@ class DeviceHandler(DeviceWorkerHandler):
                     f"org.matrix.msc3890.local_notification_settings.{device_id}",
                 )
 
+            # Delete device messages asynchronously and in batches using the task scheduler
+            await self._task_scheduler.schedule_task(
+                DELETE_DEVICE_MSGS_TASK_NAME,
+                resource_id=device_id,
+                params={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "up_to_stream_id": to_device_stream_id,
+                },
+            )
+
         # Pushers are deleted after `delete_access_tokens_for_user` is called so that
         # modules using `on_logged_out` hook can use them if needed.
         await self.hs.get_pusherpool().remove_pushers_by_devices(user_id, device_ids)
 
         await self.notify_device_update(user_id, device_ids)
 
+    DEVICE_MSGS_DELETE_BATCH_LIMIT = 100
+
+    async def _delete_device_messages(
+        self,
+        task: ScheduledTask,
+    ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+        """Scheduler task to delete device messages in batch of `DEVICE_MSGS_DELETE_BATCH_LIMIT`."""
+        assert task.params is not None
+        user_id = task.params["user_id"]
+        device_id = task.params["device_id"]
+        up_to_stream_id = task.params["up_to_stream_id"]
+
+        res = await self.store.delete_messages_for_device(
+            user_id=user_id,
+            device_id=device_id,
+            up_to_stream_id=up_to_stream_id,
+            limit=DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT,
+        )
+
+        if res < DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT:
+            return TaskStatus.COMPLETE, None, None
+        else:
+            # There is probably still device messages to be deleted, let's keep the task active and it will be run
+            # again in a subsequent scheduler loop run (probably the next one, if not too many tasks are running).
+            return TaskStatus.ACTIVE, None, None
+
     async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
         """Update the given device
 
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index b3be7a86f0..5dc76ef588 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from synapse.api.constants import (
     AccountDataTypes,
@@ -23,7 +23,6 @@ from synapse.api.constants import (
     Membership,
 )
 from synapse.api.errors import SynapseError
-from synapse.events import EventBase
 from synapse.events.utils import SerializeEventConfig
 from synapse.events.validator import EventValidator
 from synapse.handlers.presence import format_user_presence_state
@@ -35,7 +34,6 @@ from synapse.types import (
     JsonDict,
     Requester,
     RoomStreamToken,
-    StateMap,
     StreamKeyType,
     StreamToken,
     UserID,
@@ -199,9 +197,7 @@ class InitialSyncHandler:
                     deferred_room_state = run_in_background(
                         self._state_storage_controller.get_state_for_events,
                         [event.event_id],
-                    ).addCallback(
-                        lambda states: cast(StateMap[EventBase], states[event.event_id])
-                    )
+                    ).addCallback(lambda states: states[event.event_id])
 
                 (messages, token), current_state = await make_deferred_yieldable(
                     gather_results(
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index e5ac9096cc..19cf5a2b43 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -713,7 +713,7 @@ class PaginationHandler:
         self,
         delete_id: str,
         room_id: str,
-        requester_user_id: str,
+        requester_user_id: Optional[str],
         new_room_user_id: Optional[str] = None,
         new_room_name: Optional[str] = None,
         message: Optional[str] = None,
@@ -732,6 +732,10 @@ class PaginationHandler:
             requester_user_id:
                 User who requested the action. Will be recorded as putting the room on the
                 blocking list.
+                If None, the action was not manually requested but instead
+                triggered automatically, e.g. through a Synapse module
+                or some other policy.
+                MUST NOT be None if block=True.
             new_room_user_id:
                 If set, a new room will be created with this user ID
                 as the creator and admin, and all users in the old room will be
@@ -818,7 +822,7 @@ class PaginationHandler:
     def start_shutdown_and_purge_room(
         self,
         room_id: str,
-        requester_user_id: str,
+        requester_user_id: Optional[str],
         new_room_user_id: Optional[str] = None,
         new_room_name: Optional[str] = None,
         message: Optional[str] = None,
@@ -833,6 +837,10 @@ class PaginationHandler:
             requester_user_id:
                 User who requested the action and put the room on the
                 blocking list.
+                If None, the action was not manually requested but instead
+                triggered automatically, e.g. through a Synapse module
+                or some other policy.
+                MUST NOT be None if block=True.
             new_room_user_id:
                 If set, a new room will be created with this user ID
                 as the creator and admin, and all users in the old room will be
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index f31e18328b..375c7d0901 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -13,13 +13,56 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""This module is responsible for keeping track of presence status of local
+"""
+This module is responsible for keeping track of presence status of local
 and remote users.
 
 The methods that define policy are:
     - PresenceHandler._update_states
     - PresenceHandler._handle_timeouts
     - should_notify
+
+# Tracking local presence
+
+For local users, presence is tracked on a per-device basis. When a user has multiple
+devices the user presence state is derived by coalescing the presence from each
+device:
+
+    BUSY > ONLINE > UNAVAILABLE > OFFLINE
+
+The time that each device was last active and last synced is tracked in order to
+automatically downgrade a device's presence state:
+
+    A device may move from ONLINE -> UNAVAILABLE, if it has not been active for
+    a period of time.
+
+    A device may go from any state -> OFFLINE, if it is not active and has not
+    synced for a period of time.
+
+The timeouts are handled using a wheel timer, which has coarse buckets. Timings
+do not need to be exact.
+
+Generally a device's presence state is updated whenever a user syncs (via the
+set_presence parameter), when the presence API is called, or if "pro-active"
+events occur, including:
+
+* Sending an event, receipt, read marker.
+* Updating typing status.
+
+The busy state has special status that it cannot is not downgraded by a call to
+sync with a lower priority state *and* it takes a long period of time to transition
+to offline.
+
+# Persisting (and restoring) presence
+
+For all users, presence is persisted on a per-user basis. Data is kept in-memory
+and persisted periodically. When Synapse starts each worker loads the current
+presence state and then tracks the presence stream to keep itself up-to-date.
+
+When restoring presence for local users a pseudo-device is created to match the
+user state; this device follows the normal timeout logic (see above) and will
+automatically be replaced with any information from currently available devices.
+
 """
 import abc
 import contextlib
@@ -30,6 +73,7 @@ from contextlib import contextmanager
 from types import TracebackType
 from typing import (
     TYPE_CHECKING,
+    AbstractSet,
     Any,
     Callable,
     Collection,
@@ -49,7 +93,7 @@ from prometheus_client import Counter
 import synapse.metrics
 from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState
 from synapse.api.errors import SynapseError
-from synapse.api.presence import UserPresenceState
+from synapse.api.presence import UserDevicePresenceState, UserPresenceState
 from synapse.appservice import ApplicationService
 from synapse.events.presence_router import PresenceRouter
 from synapse.logging.context import run_in_background
@@ -111,6 +155,8 @@ LAST_ACTIVE_GRANULARITY = 60 * 1000
 # How long to wait until a new /events or /sync request before assuming
 # the client has gone.
 SYNC_ONLINE_TIMEOUT = 30 * 1000
+# Busy status waits longer, but does eventually go offline.
+BUSY_ONLINE_TIMEOUT = 60 * 60 * 1000
 
 # How long to wait before marking the user as idle. Compared against last active
 IDLE_TIMER = 5 * 60 * 1000
@@ -137,6 +183,7 @@ class BasePresenceHandler(abc.ABC):
     writer"""
 
     def __init__(self, hs: "HomeServer"):
+        self.hs = hs
         self.clock = hs.get_clock()
         self.store = hs.get_datastores().main
         self._storage_controllers = hs.get_storage_controllers()
@@ -162,6 +209,7 @@ class BasePresenceHandler(abc.ABC):
             self.VALID_PRESENCE += (PresenceState.BUSY,)
 
         active_presence = self.store.take_presence_startup_info()
+        # The combined status across all user devices.
         self.user_to_current_state = {state.user_id: state for state in active_presence}
 
     @abc.abstractmethod
@@ -426,8 +474,6 @@ class _NullContextManager(ContextManager[None]):
 class WorkerPresenceHandler(BasePresenceHandler):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
-        self.hs = hs
-
         self._presence_writer_instance = hs.config.worker.writers.presence[0]
 
         # Route presence EDUs to the right worker
@@ -691,7 +737,6 @@ class WorkerPresenceHandler(BasePresenceHandler):
 class PresenceHandler(BasePresenceHandler):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
-        self.hs = hs
         self.wheel_timer: WheelTimer[str] = WheelTimer()
         self.notifier = hs.get_notifier()
 
@@ -708,9 +753,27 @@ class PresenceHandler(BasePresenceHandler):
             lambda: len(self.user_to_current_state),
         )
 
+        # The per-device presence state, maps user to devices to per-device presence state.
+        self._user_to_device_to_current_state: Dict[
+            str, Dict[Optional[str], UserDevicePresenceState]
+        ] = {}
+
         now = self.clock.time_msec()
         if self._presence_enabled:
             for state in self.user_to_current_state.values():
+                # Create a psuedo-device to properly handle time outs. This will
+                # be overridden by any "real" devices within SYNC_ONLINE_TIMEOUT.
+                pseudo_device_id = None
+                self._user_to_device_to_current_state[state.user_id] = {
+                    pseudo_device_id: UserDevicePresenceState(
+                        user_id=state.user_id,
+                        device_id=pseudo_device_id,
+                        state=state.state,
+                        last_active_ts=state.last_active_ts,
+                        last_sync_ts=state.last_user_sync_ts,
+                    )
+                }
+
                 self.wheel_timer.insert(
                     now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER
                 )
@@ -752,7 +815,7 @@ class PresenceHandler(BasePresenceHandler):
 
         # Keeps track of the number of *ongoing* syncs on other processes.
         #
-        # While any sync is ongoing on another process the user will never
+        # While any sync is ongoing on another process the user's device will never
         # go offline.
         #
         # Each process has a unique identifier and an update frequency. If
@@ -981,22 +1044,21 @@ class PresenceHandler(BasePresenceHandler):
 
         timers_fired_counter.inc(len(states))
 
-        syncing_user_ids = {
-            user_id
-            for (user_id, _), count in self._user_device_to_num_current_syncs.items()
+        # Set of user ID & device IDs which are currently syncing.
+        syncing_user_devices = {
+            user_id_device_id
+            for user_id_device_id, count in self._user_device_to_num_current_syncs.items()
             if count
         }
-        syncing_user_ids.update(
-            user_id
-            for user_id, _ in itertools.chain(
-                *self.external_process_to_current_syncs.values()
-            )
+        syncing_user_devices.update(
+            itertools.chain(*self.external_process_to_current_syncs.values())
         )
 
         changes = handle_timeouts(
             states,
             is_mine_fn=self.is_mine_id,
-            syncing_user_ids=syncing_user_ids,
+            syncing_user_devices=syncing_user_devices,
+            user_to_devices=self._user_to_device_to_current_state,
             now=now,
         )
 
@@ -1016,11 +1078,26 @@ class PresenceHandler(BasePresenceHandler):
 
         bump_active_time_counter.inc()
 
-        prev_state = await self.current_state_for_user(user_id)
+        now = self.clock.time_msec()
 
-        new_fields: Dict[str, Any] = {"last_active_ts": self.clock.time_msec()}
-        if prev_state.state == PresenceState.UNAVAILABLE:
-            new_fields["state"] = PresenceState.ONLINE
+        # Update the device information & mark the device as online if it was
+        # unavailable.
+        devices = self._user_to_device_to_current_state.setdefault(user_id, {})
+        device_state = devices.setdefault(
+            device_id,
+            UserDevicePresenceState.default(user_id, device_id),
+        )
+        device_state.last_active_ts = now
+        if device_state.state == PresenceState.UNAVAILABLE:
+            device_state.state = PresenceState.ONLINE
+
+        # Update the user state, this will always update last_active_ts and
+        # might update the presence state.
+        prev_state = await self.current_state_for_user(user_id)
+        new_fields: Dict[str, Any] = {
+            "last_active_ts": now,
+            "state": _combine_device_states(devices.values()),
+        }
 
         await self._update_states([prev_state.copy_and_replace(**new_fields)])
 
@@ -1132,6 +1209,12 @@ class PresenceHandler(BasePresenceHandler):
             if is_syncing and (user_id, device_id) not in process_presence:
                 process_presence.add((user_id, device_id))
             elif not is_syncing and (user_id, device_id) in process_presence:
+                devices = self._user_to_device_to_current_state.setdefault(user_id, {})
+                device_state = devices.setdefault(
+                    device_id, UserDevicePresenceState.default(user_id, device_id)
+                )
+                device_state.last_sync_ts = sync_time_msec
+
                 new_state = prev_state.copy_and_replace(
                     last_user_sync_ts=sync_time_msec
                 )
@@ -1151,11 +1234,24 @@ class PresenceHandler(BasePresenceHandler):
             process_presence = self.external_process_to_current_syncs.pop(
                 process_id, set()
             )
-            prev_states = await self.current_state_for_users(
-                {user_id for user_id, device_id in process_presence}
-            )
+
             time_now_ms = self.clock.time_msec()
 
+            # Mark each device as having a last sync time.
+            updated_users = set()
+            for user_id, device_id in process_presence:
+                device_state = self._user_to_device_to_current_state.setdefault(
+                    user_id, {}
+                ).setdefault(
+                    device_id, UserDevicePresenceState.default(user_id, device_id)
+                )
+
+                device_state.last_sync_ts = time_now_ms
+                updated_users.add(user_id)
+
+            # Update each user (and insert into the appropriate timers to check if
+            # they've gone offline).
+            prev_states = await self.current_state_for_users(updated_users)
             await self._update_states(
                 [
                     prev_state.copy_and_replace(last_user_sync_ts=time_now_ms)
@@ -1277,6 +1373,20 @@ class PresenceHandler(BasePresenceHandler):
         if prev_state.state == PresenceState.BUSY and is_sync:
             presence = PresenceState.BUSY
 
+        # Update the device specific information.
+        devices = self._user_to_device_to_current_state.setdefault(user_id, {})
+        device_state = devices.setdefault(
+            device_id,
+            UserDevicePresenceState.default(user_id, device_id),
+        )
+        device_state.state = presence
+        device_state.last_active_ts = now
+        if is_sync:
+            device_state.last_sync_ts = now
+
+        # Based on the state of each user's device calculate the new presence state.
+        presence = _combine_device_states(devices.values())
+
         new_fields = {"state": presence}
 
         if presence == PresenceState.ONLINE or presence == PresenceState.BUSY:
@@ -1873,7 +1983,8 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
 def handle_timeouts(
     user_states: List[UserPresenceState],
     is_mine_fn: Callable[[str], bool],
-    syncing_user_ids: Set[str],
+    syncing_user_devices: AbstractSet[Tuple[str, Optional[str]]],
+    user_to_devices: Dict[str, Dict[Optional[str], UserDevicePresenceState]],
     now: int,
 ) -> List[UserPresenceState]:
     """Checks the presence of users that have timed out and updates as
@@ -1882,7 +1993,8 @@ def handle_timeouts(
     Args:
         user_states: List of UserPresenceState's to check.
         is_mine_fn: Function that returns if a user_id is ours
-        syncing_user_ids: Set of user_ids with active syncs.
+        syncing_user_devices: A set of (user ID, device ID) tuples with active syncs..
+        user_to_devices: A map of user ID to device ID to UserDevicePresenceState.
         now: Current time in ms.
 
     Returns:
@@ -1891,9 +2003,16 @@ def handle_timeouts(
     changes = {}  # Actual changes we need to notify people about
 
     for state in user_states:
-        is_mine = is_mine_fn(state.user_id)
-
-        new_state = handle_timeout(state, is_mine, syncing_user_ids, now)
+        user_id = state.user_id
+        is_mine = is_mine_fn(user_id)
+
+        new_state = handle_timeout(
+            state,
+            is_mine,
+            syncing_user_devices,
+            user_to_devices.get(user_id, {}),
+            now,
+        )
         if new_state:
             changes[state.user_id] = new_state
 
@@ -1901,14 +2020,19 @@ def handle_timeouts(
 
 
 def handle_timeout(
-    state: UserPresenceState, is_mine: bool, syncing_user_ids: Set[str], now: int
+    state: UserPresenceState,
+    is_mine: bool,
+    syncing_device_ids: AbstractSet[Tuple[str, Optional[str]]],
+    user_devices: Dict[Optional[str], UserDevicePresenceState],
+    now: int,
 ) -> Optional[UserPresenceState]:
     """Checks the presence of the user to see if any of the timers have elapsed
 
     Args:
-        state
+        state: UserPresenceState to check.
         is_mine: Whether the user is ours
-        syncing_user_ids: Set of user_ids with active syncs.
+        syncing_user_devices: A set of (user ID, device ID) tuples with active syncs..
+        user_devices: A map of device ID to UserDevicePresenceState.
         now: Current time in ms.
 
     Returns:
@@ -1919,34 +2043,63 @@ def handle_timeout(
         return None
 
     changed = False
-    user_id = state.user_id
 
     if is_mine:
-        if state.state == PresenceState.ONLINE:
-            if now - state.last_active_ts > IDLE_TIMER:
-                # Currently online, but last activity ages ago so auto
-                # idle
-                state = state.copy_and_replace(state=PresenceState.UNAVAILABLE)
-                changed = True
-            elif now - state.last_active_ts > LAST_ACTIVE_GRANULARITY:
-                # So that we send down a notification that we've
-                # stopped updating.
+        # Check per-device whether the device should be considered idle or offline
+        # due to timeouts.
+        device_changed = False
+        offline_devices = []
+        for device_id, device_state in user_devices.items():
+            if device_state.state == PresenceState.ONLINE:
+                if now - device_state.last_active_ts > IDLE_TIMER:
+                    # Currently online, but last activity ages ago so auto
+                    # idle
+                    device_state.state = PresenceState.UNAVAILABLE
+                    device_changed = True
+
+            # If there are have been no sync for a while (and none ongoing),
+            # set presence to offline.
+            if (state.user_id, device_id) not in syncing_device_ids:
+                # If the user has done something recently but hasn't synced,
+                # don't set them as offline.
+                sync_or_active = max(
+                    device_state.last_sync_ts, device_state.last_active_ts
+                )
+
+                # Implementations aren't meant to timeout a device with a busy
+                # state, but it needs to timeout *eventually* or else the user
+                # will be stuck in that state.
+                online_timeout = (
+                    BUSY_ONLINE_TIMEOUT
+                    if device_state.state == PresenceState.BUSY
+                    else SYNC_ONLINE_TIMEOUT
+                )
+                if now - sync_or_active > online_timeout:
+                    # Mark the device as going offline.
+                    offline_devices.append(device_id)
+                    device_changed = True
+
+        # Offline devices are not needed and do not add information.
+        for device_id in offline_devices:
+            user_devices.pop(device_id)
+
+        # If the presence state of the devices changed, then (maybe) update
+        # the user's overall presence state.
+        if device_changed:
+            new_presence = _combine_device_states(user_devices.values())
+            if new_presence != state.state:
+                state = state.copy_and_replace(state=new_presence)
                 changed = True
 
+        if now - state.last_active_ts > LAST_ACTIVE_GRANULARITY:
+            # So that we send down a notification that we've
+            # stopped updating.
+            changed = True
+
         if now - state.last_federation_update_ts > FEDERATION_PING_INTERVAL:
             # Need to send ping to other servers to ensure they don't
             # timeout and set us to offline
             changed = True
-
-        # If there are have been no sync for a while (and none ongoing),
-        # set presence to offline
-        if user_id not in syncing_user_ids:
-            # If the user has done something recently but hasn't synced,
-            # don't set them as offline.
-            sync_or_active = max(state.last_user_sync_ts, state.last_active_ts)
-            if now - sync_or_active > SYNC_ONLINE_TIMEOUT:
-                state = state.copy_and_replace(state=PresenceState.OFFLINE)
-                changed = True
     else:
         # We expect to be poked occasionally by the other side.
         # This is to protect against forgetful/buggy servers, so that
@@ -2021,6 +2174,13 @@ def handle_update(
                 new_state = new_state.copy_and_replace(last_federation_update_ts=now)
                 federation_ping = True
 
+        if new_state.state == PresenceState.BUSY:
+            wheel_timer.insert(
+                now=now,
+                obj=user_id,
+                then=new_state.last_user_sync_ts + BUSY_ONLINE_TIMEOUT,
+            )
+
     else:
         wheel_timer.insert(
             now=now,
@@ -2036,6 +2196,46 @@ def handle_update(
     return new_state, persist_and_notify, federation_ping
 
 
+PRESENCE_BY_PRIORITY = {
+    PresenceState.BUSY: 4,
+    PresenceState.ONLINE: 3,
+    PresenceState.UNAVAILABLE: 2,
+    PresenceState.OFFLINE: 1,
+}
+
+
+def _combine_device_states(
+    device_states: Iterable[UserDevicePresenceState],
+) -> str:
+    """
+    Find the device to use presence information from.
+
+    Orders devices by priority, then last_active_ts.
+
+    Args:
+        device_states: An iterable of device presence states
+
+    Return:
+        The combined presence state.
+    """
+
+    # Based on (all) the user's devices calculate the new presence state.
+    presence = PresenceState.OFFLINE
+    last_active_ts = -1
+
+    # Find the device to use the presence state of based on the presence priority,
+    # but tie-break with how recently the device has been seen.
+    for device_state in device_states:
+        if (PRESENCE_BY_PRIORITY[device_state.state], device_state.last_active_ts) > (
+            PRESENCE_BY_PRIORITY[presence],
+            last_active_ts,
+        ):
+            presence = device_state.state
+            last_active_ts = device_state.last_active_ts
+
+    return presence
+
+
 async def get_interested_parties(
     store: DataStore, presence_router: PresenceRouter, states: List[UserPresenceState]
 ) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 0513e28aab..7a762c8511 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1787,7 +1787,7 @@ class RoomShutdownHandler:
     async def shutdown_room(
         self,
         room_id: str,
-        requester_user_id: str,
+        requester_user_id: Optional[str],
         new_room_user_id: Optional[str] = None,
         new_room_name: Optional[str] = None,
         message: Optional[str] = None,
@@ -1811,6 +1811,10 @@ class RoomShutdownHandler:
             requester_user_id:
                 User who requested the action and put the room on the
                 blocking list.
+                If None, the action was not manually requested but instead
+                triggered automatically, e.g. through a Synapse module
+                or some other policy.
+                MUST NOT be None if block=True.
             new_room_user_id:
                 If set, a new room will be created with this user ID
                 as the creator and admin, and all users in the old room will be
@@ -1863,6 +1867,10 @@ class RoomShutdownHandler:
 
         # Action the block first (even if the room doesn't exist yet)
         if block:
+            if requester_user_id is None:
+                raise ValueError(
+                    "shutdown_room: block=True not allowed when requester_user_id is None."
+                )
             # This will work even if the room is already blocked, but that is
             # desirable in case the first attempt at blocking the room failed below.
             await self.store.block_room(room_id, requester_user_id)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 60a9f341b5..0ccd7d250c 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -40,6 +40,7 @@ from synapse.api.filtering import FilterCollection
 from synapse.api.presence import UserPresenceState
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
+from synapse.handlers.device import DELETE_DEVICE_MSGS_TASK_NAME
 from synapse.handlers.relations import BundledAggregations
 from synapse.logging import issue9533_logger
 from synapse.logging.context import current_context
@@ -268,6 +269,7 @@ class SyncHandler:
         self._storage_controllers = hs.get_storage_controllers()
         self._state_storage_controller = self._storage_controllers.state
         self._device_handler = hs.get_device_handler()
+        self._task_scheduler = hs.get_task_scheduler()
 
         self.should_calculate_push_rules = hs.config.push.enable_push
 
@@ -360,11 +362,19 @@ class SyncHandler:
         # (since we now know that the device has received them)
         if since_token is not None:
             since_stream_id = since_token.to_device_key
-            deleted = await self.store.delete_messages_for_device(
-                sync_config.user.to_string(), sync_config.device_id, since_stream_id
+            # Delete device messages asynchronously and in batches using the task scheduler
+            await self._task_scheduler.schedule_task(
+                DELETE_DEVICE_MSGS_TASK_NAME,
+                resource_id=sync_config.device_id,
+                params={
+                    "user_id": sync_config.user.to_string(),
+                    "device_id": sync_config.device_id,
+                    "up_to_stream_id": since_stream_id,
+                },
             )
             logger.debug(
-                "Deleted %d to-device messages up to %d", deleted, since_stream_id
+                "Deletion of to-device messages up to %d scheduled",
+                since_stream_id,
             )
 
         if timeout == 0 or since_token is None or full_state:
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 91a24efcd0..a3a396bb37 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -399,15 +399,34 @@ class MatrixHostnameEndpoint:
         if port or _is_ip_literal(host):
             return [Server(host, port or 8448)]
 
+        # Check _matrix-fed._tcp SRV record.
         logger.debug("Looking up SRV record for %s", host.decode(errors="replace"))
+        server_list = await self._srv_resolver.resolve_service(
+            b"_matrix-fed._tcp." + host
+        )
+
+        if server_list:
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(
+                    "Got %s from SRV lookup for %s",
+                    ", ".join(map(str, server_list)),
+                    host.decode(errors="replace"),
+                )
+            return server_list
+
+        # No _matrix-fed._tcp SRV record, fallback to legacy _matrix._tcp SRV record.
+        logger.debug(
+            "Looking up deprecated SRV record for %s", host.decode(errors="replace")
+        )
         server_list = await self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
 
         if server_list:
-            logger.debug(
-                "Got %s from SRV lookup for %s",
-                ", ".join(map(str, server_list)),
-                host.decode(errors="replace"),
-            )
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(
+                    "Got %s from deprecated SRV lookup for %s",
+                    ", ".join(map(str, server_list)),
+                    host.decode(errors="replace"),
+                )
             return server_list
 
         # No SRV records, so we fallback to host and 8448
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 64c6ae4512..bf7e311026 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -728,7 +728,7 @@ async def _unwrap_awaitable(awaitable: Awaitable[R]) -> R:
 
 
 @overload
-def preserve_fn(  # type: ignore[misc]
+def preserve_fn(
     f: Callable[P, Awaitable[R]],
 ) -> Callable[P, "defer.Deferred[R]"]:
     # The `type: ignore[misc]` above suppresses
@@ -756,7 +756,7 @@ def preserve_fn(
 
 
 @overload
-def run_in_background(  # type: ignore[misc]
+def run_in_background(
     f: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs
 ) -> "defer.Deferred[R]":
     # The `type: ignore[misc]` above suppresses
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 2f00a7ba20..d6efe10a28 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -1730,6 +1730,19 @@ class ModuleApi:
         room_alias_str = room_alias.to_string() if room_alias else None
         return room_id, room_alias_str
 
+    async def delete_room(self, room_id: str) -> None:
+        """
+        Schedules the deletion of a room from Synapse's database.
+
+        If the room is already being deleted, this method does nothing.
+        This method does not wait for the room to be deleted.
+
+        Added in Synapse v1.89.0.
+        """
+        # Future extensions to this method might want to e.g. allow use of `force_purge`.
+        # TODO In the future we should make sure this is persistent.
+        self._hs.get_pagination_handler().start_shutdown_and_purge_room(room_id, None)
+
     async def set_displayname(
         self,
         user_id: UserID,
diff --git a/synapse/module_api/callbacks/third_party_event_rules_callbacks.py b/synapse/module_api/callbacks/third_party_event_rules_callbacks.py
index 911f37ba42..ecaeef3511 100644
--- a/synapse/module_api/callbacks/third_party_event_rules_callbacks.py
+++ b/synapse/module_api/callbacks/third_party_event_rules_callbacks.py
@@ -40,7 +40,7 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
     [str, StateMap[EventBase], str], Awaitable[bool]
 ]
 ON_NEW_EVENT_CALLBACK = Callable[[EventBase, StateMap[EventBase]], Awaitable]
-CHECK_CAN_SHUTDOWN_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
+CHECK_CAN_SHUTDOWN_ROOM_CALLBACK = Callable[[Optional[str], str], Awaitable[bool]]
 CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]]
 ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable]
 ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable]
@@ -429,12 +429,17 @@ class ThirdPartyEventRulesModuleApiCallbacks:
                     "Failed to run module API callback %s: %s", callback, e
                 )
 
-    async def check_can_shutdown_room(self, user_id: str, room_id: str) -> bool:
+    async def check_can_shutdown_room(
+        self, user_id: Optional[str], room_id: str
+    ) -> bool:
         """Intercept requests to shutdown a room. If `False` is returned, the
          room must not be shut down.
 
         Args:
-            requester: The ID of the user requesting the shutdown.
+            user_id: The ID of the user requesting the shutdown.
+                If no user ID is supplied, then the room is being shut down through
+                some mechanism other than a user's request, e.g. through a module's
+                request.
             room_id: The ID of the room.
         """
         for callback in self._check_can_shutdown_room_callbacks:
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index b471fcb064..744e98c6d0 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -349,7 +349,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                     table="devices",
                     column="user_id",
                     iterable=user_ids_to_query,
-                    keyvalues={"user_id": user_id, "hidden": False},
+                    keyvalues={"hidden": False},
                     retcols=("device_id",),
                 )
 
@@ -445,13 +445,18 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
     @trace
     async def delete_messages_for_device(
-        self, user_id: str, device_id: Optional[str], up_to_stream_id: int
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        up_to_stream_id: int,
+        limit: int,
     ) -> int:
         """
         Args:
             user_id: The recipient user_id.
             device_id: The recipient device_id.
             up_to_stream_id: Where to delete messages up to.
+            limit: maximum number of messages to delete
 
         Returns:
             The number of messages deleted.
@@ -472,12 +477,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 log_kv({"message": "No changes in cache since last check"})
                 return 0
 
+        ROW_ID_NAME = self.database_engine.row_id_name
+
         def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
-            sql = (
-                "DELETE FROM device_inbox"
-                " WHERE user_id = ? AND device_id = ?"
-                " AND stream_id <= ?"
-            )
+            sql = f"""
+                DELETE FROM device_inbox WHERE {ROW_ID_NAME} IN (
+                  SELECT {ROW_ID_NAME} FROM device_inbox
+                  WHERE user_id = ? AND device_id = ? AND stream_id <= ?
+                  LIMIT {limit}
+                )
+                """
             txn.execute(sql, (user_id, device_id, up_to_stream_id))
             return txn.rowcount
 
@@ -487,6 +496,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
         log_kv({"message": f"deleted {count} messages for device", "count": count})
 
+        # In this case we don't know if we hit the limit or the delete is complete
+        # so let's not update the cache.
+        if count == limit:
+            return count
+
         # Update the cache, ensuring that we only ever increase the value
         updated_last_deleted_stream_id = self._last_device_delete_cache.get(
             (user_id, device_id), 0
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index e4162f846b..324fdfa892 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1766,14 +1766,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
             self.db_pool.simple_delete_many_txn(
                 txn,
-                table="device_inbox",
-                column="device_id",
-                values=device_ids,
-                keyvalues={"user_id": user_id},
-            )
-
-            self.db_pool.simple_delete_many_txn(
-                txn,
                 table="device_auth_providers",
                 column="device_id",
                 values=device_ids,
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index a3b4744855..57aa4921e1 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -221,12 +221,17 @@ class KeyStore(CacheInvalidationWorkerStore):
             """Processes a batch of keys to fetch, and adds the result to `keys`."""
 
             # batch_iter always returns tuples so it's safe to do len(batch)
-            sql = """
-            SELECT server_name, key_id, key_json, ts_valid_until_ms
-            FROM server_keys_json WHERE 1=0
-            """ + " OR (server_name=? AND key_id=?)" * len(
-                batch
-            )
+            where_clause = " OR (server_name=? AND key_id=?)" * len(batch)
+
+            # `server_keys_json` can have multiple entries per server (one per
+            # remote server we fetched from, if using perspectives). Order by
+            # `ts_added_ms` so the most recently fetched one always wins.
+            sql = f"""
+                SELECT server_name, key_id, key_json, ts_valid_until_ms
+                FROM server_keys_json WHERE 1=0
+                {where_clause}
+                ORDER BY ts_added_ms
+            """
 
             txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
 
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 5ee5c7ad9f..e4d10ff250 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -939,11 +939,7 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
         receipts."""
 
         def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None:
-            if isinstance(self.database_engine, PostgresEngine):
-                ROW_ID_NAME = "ctid"
-            else:
-                ROW_ID_NAME = "rowid"
-
+            ROW_ID_NAME = self.database_engine.row_id_name
             # Identify any duplicate receipts arising from
             # https://github.com/matrix-org/synapse/issues/14406.
             # The following query takes less than a minute on matrix.org.
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index 0b5b3bf03e..b1a2418cbd 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -100,6 +100,12 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM
         """Gets a string giving the server version. For example: '3.22.0'"""
         ...
 
+    @property
+    @abc.abstractmethod
+    def row_id_name(self) -> str:
+        """Gets the literal name representing a row id for this engine."""
+        ...
+
     @abc.abstractmethod
     def in_transaction(self, conn: ConnectionType) -> bool:
         """Whether the connection is currently in a transaction."""
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 05a72dc554..6309363217 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -211,6 +211,10 @@ class PostgresEngine(
         else:
             return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)
 
+    @property
+    def row_id_name(self) -> str:
+        return "ctid"
+
     def in_transaction(self, conn: psycopg2.extensions.connection) -> bool:
         return conn.status != psycopg2.extensions.STATUS_READY
 
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index ca8c59297c..802069e1e1 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -123,6 +123,10 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]):
         """Gets a string giving the server version. For example: '3.22.0'."""
         return "%i.%i.%i" % sqlite3.sqlite_version_info
 
+    @property
+    def row_id_name(self) -> str:
+        return "rowid"
+
     def in_transaction(self, conn: sqlite3.Connection) -> bool:
         return conn.in_transaction
 
diff --git a/synapse/storage/schema/main/delta/48/group_unique_indexes.py b/synapse/storage/schema/main/delta/48/group_unique_indexes.py
index ad2da4c8af..622686d28f 100644
--- a/synapse/storage/schema/main/delta/48/group_unique_indexes.py
+++ b/synapse/storage/schema/main/delta/48/group_unique_indexes.py
@@ -14,7 +14,7 @@
 
 
 from synapse.storage.database import LoggingTransaction
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine
 from synapse.storage.prepare_database import get_statements
 
 FIX_INDEXES = """
@@ -37,7 +37,7 @@ CREATE INDEX group_rooms_r_idx ON group_rooms(room_id);
 
 
 def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None:
-    rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid"
+    rowid = database_engine.row_id_name
 
     # remove duplicates from group_users & group_invites tables
     cur.execute(
diff --git a/synapse/util/gai_resolver.py b/synapse/util/gai_resolver.py
index 214eb17fbc..fecf829ade 100644
--- a/synapse/util/gai_resolver.py
+++ b/synapse/util/gai_resolver.py
@@ -136,7 +136,7 @@ class GAIResolver:
 
     # The types on IHostnameResolver is incorrect in Twisted, see
     # https://twistedmatrix.com/trac/ticket/10276
-    def resolveHostName(  # type: ignore[override]
+    def resolveHostName(
         self,
         resolutionReceiver: IResolutionReceiver,
         hostName: str,
diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py
index 9e89aeb748..9b2581e51a 100644
--- a/synapse/util/task_scheduler.py
+++ b/synapse/util/task_scheduler.py
@@ -77,6 +77,7 @@ class TaskScheduler:
     LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000  # 24hrs
 
     def __init__(self, hs: "HomeServer"):
+        self._hs = hs
         self._store = hs.get_datastores().main
         self._clock = hs.get_clock()
         self._running_tasks: Set[str] = set()
@@ -97,8 +98,6 @@ class TaskScheduler:
                 "handle_scheduled_tasks",
                 self._handle_scheduled_tasks,
             )
-        else:
-            self.replication_client = hs.get_replication_command_handler()
 
     def register_action(
         self,
@@ -133,7 +132,7 @@ class TaskScheduler:
         params: Optional[JsonMapping] = None,
     ) -> str:
         """Schedule a new potentially resumable task. A function matching the specified
-        `action` should have been previously registered with `register_action`.
+        `action` should have be registered with `register_action` before the task is run.
 
         Args:
             action: the name of a previously registered action
@@ -149,11 +148,6 @@ class TaskScheduler:
         Returns:
             The id of the scheduled task
         """
-        if action not in self._actions:
-            raise Exception(
-                f"No function associated with action {action} of the scheduled task"
-            )
-
         status = TaskStatus.SCHEDULED
         if timestamp is None or timestamp < self._clock.time_msec():
             timestamp = self._clock.time_msec()
@@ -175,7 +169,7 @@ class TaskScheduler:
             if self._run_background_tasks:
                 await self._launch_task(task)
             else:
-                self.replication_client.send_new_active_task(task.id)
+                self._hs.get_replication_command_handler().send_new_active_task(task.id)
 
         return task.id
 
@@ -315,7 +309,10 @@ class TaskScheduler:
         """
         assert self._run_background_tasks
 
-        assert task.action in self._actions
+        if task.action not in self._actions:
+            raise Exception(
+                f"No function associated with action {task.action} of the scheduled task {task.id}"
+            )
         function = self._actions[task.action]
 
         async def wrapper() -> None: