summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2023-09-13 11:12:26 +0100
committerErik Johnston <erik@matrix.org>2023-09-13 11:12:26 +0100
commit3bb8cce692532ac61494c1767720d3bc9d60f08e (patch)
tree58b737f49bfe29959a8fb455df3ea57db87ac296 /synapse
parentMerge branch 'release-v1.92' into matrix-org-hotfixes (diff)
parentImprove logging of replication (#16309) (diff)
downloadsynapse-3bb8cce692532ac61494c1767720d3bc9d60f08e.tar.xz
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
Diffstat (limited to 'synapse')
-rw-r--r--synapse/_scripts/update_synapse_database.py1
-rw-r--r--synapse/api/presence.py43
-rw-r--r--synapse/appservice/api.py32
-rw-r--r--synapse/config/_base.py7
-rw-r--r--synapse/config/cas.py3
-rw-r--r--synapse/config/oembed.py2
-rw-r--r--synapse/crypto/keyring.py35
-rw-r--r--synapse/events/snapshot.py2
-rw-r--r--synapse/handlers/cas.py2
-rw-r--r--synapse/handlers/device.py84
-rw-r--r--synapse/handlers/initial_sync.py8
-rw-r--r--synapse/handlers/pagination.py12
-rw-r--r--synapse/handlers/presence.py300
-rw-r--r--synapse/handlers/room.py10
-rw-r--r--synapse/handlers/send_email.py10
-rw-r--r--synapse/handlers/sync.py16
-rw-r--r--synapse/http/federation/matrix_federation_agent.py29
-rw-r--r--synapse/logging/context.py4
-rw-r--r--synapse/logging/opentracing.py10
-rw-r--r--synapse/media/url_previewer.py4
-rw-r--r--synapse/module_api/__init__.py13
-rw-r--r--synapse/module_api/callbacks/third_party_event_rules_callbacks.py11
-rw-r--r--synapse/push/mailer.py33
-rw-r--r--synapse/replication/http/devices.py4
-rw-r--r--synapse/replication/tcp/handler.py2
-rw-r--r--synapse/replication/tcp/resource.py7
-rw-r--r--synapse/rest/__init__.py2
-rw-r--r--synapse/rest/client/account.py116
-rw-r--r--synapse/rest/client/notifications.py2
-rw-r--r--synapse/rest/synapse/client/unsubscribe.py17
-rw-r--r--synapse/storage/background_updates.py2
-rw-r--r--synapse/storage/controllers/persist_events.py8
-rw-r--r--synapse/storage/database.py5
-rw-r--r--synapse/storage/databases/main/deviceinbox.py28
-rw-r--r--synapse/storage/databases/main/devices.py34
-rw-r--r--synapse/storage/databases/main/event_push_actions.py72
-rw-r--r--synapse/storage/databases/main/keys.py236
-rw-r--r--synapse/storage/databases/main/purge_events.py4
-rw-r--r--synapse/storage/databases/main/receipts.py6
-rw-r--r--synapse/storage/engines/_base.py6
-rw-r--r--synapse/storage/engines/postgres.py4
-rw-r--r--synapse/storage/engines/sqlite.py4
-rw-r--r--synapse/storage/schema/__init__.py6
-rw-r--r--synapse/storage/schema/main/delta/48/group_unique_indexes.py4
-rw-r--r--synapse/util/async_helpers.py25
-rw-r--r--synapse/util/caches/dictionary_cache.py10
-rw-r--r--synapse/util/caches/expiringcache.py20
-rw-r--r--synapse/util/caches/ttlcache.py10
-rw-r--r--synapse/util/gai_resolver.py2
-rw-r--r--synapse/util/task_scheduler.py60
50 files changed, 860 insertions, 507 deletions
diff --git a/synapse/_scripts/update_synapse_database.py b/synapse/_scripts/update_synapse_database.py

index f97aecf8d5..992ae43881 100644 --- a/synapse/_scripts/update_synapse_database.py +++ b/synapse/_scripts/update_synapse_database.py
@@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/synapse/api/presence.py b/synapse/api/presence.py
index b80aa83cb3..b78f419994 100644 --- a/synapse/api/presence.py +++ b/synapse/api/presence.py
@@ -20,18 +20,53 @@ from synapse.api.constants import PresenceState from synapse.types import JsonDict +@attr.s(slots=True, auto_attribs=True) +class UserDevicePresenceState: + """ + Represents the current presence state of a user's device. + + user_id: The user ID. + device_id: The user's device ID. + state: The presence state, see PresenceState. + last_active_ts: Time in msec that the device last interacted with server. + last_sync_ts: Time in msec that the device last *completed* a sync + (or event stream). + """ + + user_id: str + device_id: Optional[str] + state: str + last_active_ts: int + last_sync_ts: int + + @classmethod + def default( + cls, user_id: str, device_id: Optional[str] + ) -> "UserDevicePresenceState": + """Returns a default presence state.""" + return cls( + user_id=user_id, + device_id=device_id, + state=PresenceState.OFFLINE, + last_active_ts=0, + last_sync_ts=0, + ) + + @attr.s(slots=True, frozen=True, auto_attribs=True) class UserPresenceState: """Represents the current presence state of the user. - user_id - last_active: Time in msec that the user last interacted with server. - last_federation_update: Time in msec since either a) we sent a presence + user_id: The user ID. + state: The presence state, see PresenceState. + last_active_ts: Time in msec that the user last interacted with server. + last_federation_update_ts: Time in msec since either a) we sent a presence update to other servers or b) we received a presence update, depending on if is a local user or not. - last_user_sync: Time in msec that the user last *completed* a sync + last_user_sync_ts: Time in msec that the user last *completed* a sync (or event stream). status_msg: User set status message. + currently_active: True if the user is currently syncing. """ user_id: str diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index de7a94bf26..b1523be208 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py
@@ -40,6 +40,7 @@ from synapse.appservice import ( from synapse.events import EventBase from synapse.events.utils import SerializeEventConfig, serialize_event from synapse.http.client import SimpleHttpClient, is_unknown_endpoint +from synapse.logging import opentracing from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache @@ -125,6 +126,17 @@ class ApplicationServiceApi(SimpleHttpClient): hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS ) + def _get_headers(self, service: "ApplicationService") -> Dict[bytes, List[bytes]]: + """This makes sure we have always the auth header and opentracing headers set.""" + + # This is also ensured before in the functions. However this is needed to please + # the typechecks. + assert service.hs_token is not None + + headers = {b"Authorization": [b"Bearer " + service.hs_token.encode("ascii")]} + opentracing.inject_header_dict(headers, check_destination=False) + return headers + async def query_user(self, service: "ApplicationService", user_id: str) -> bool: if service.url is None: return False @@ -136,10 +148,11 @@ class ApplicationServiceApi(SimpleHttpClient): args = None if self.config.use_appservice_legacy_authorization: args = {"access_token": service.hs_token} + response = await self.get_json( f"{service.url}{APP_SERVICE_PREFIX}/users/{urllib.parse.quote(user_id)}", args, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + headers=self._get_headers(service), ) if response is not None: # just an empty json object return True @@ -162,10 +175,11 @@ class ApplicationServiceApi(SimpleHttpClient): args = None if self.config.use_appservice_legacy_authorization: args = {"access_token": service.hs_token} + response = await self.get_json( f"{service.url}{APP_SERVICE_PREFIX}/rooms/{urllib.parse.quote(alias)}", args, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + headers=self._get_headers(service), ) if response is not None: # just an empty json object return True @@ -203,10 +217,11 @@ class ApplicationServiceApi(SimpleHttpClient): **fields, b"access_token": service.hs_token, } + response = await self.get_json( f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/{kind}/{urllib.parse.quote(protocol)}", args=args, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + headers=self._get_headers(service), ) if not isinstance(response, list): logger.warning( @@ -243,10 +258,11 @@ class ApplicationServiceApi(SimpleHttpClient): args = None if self.config.use_appservice_legacy_authorization: args = {"access_token": service.hs_token} + info = await self.get_json( f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/protocol/{urllib.parse.quote(protocol)}", args, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + headers=self._get_headers(service), ) if not _is_valid_3pe_metadata(info): @@ -283,7 +299,7 @@ class ApplicationServiceApi(SimpleHttpClient): await self.post_json_get_json( uri=f"{service.url}{APP_SERVICE_PREFIX}/ping", post_json={"transaction_id": txn_id}, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + headers=self._get_headers(service), ) async def push_bulk( @@ -364,7 +380,7 @@ class ApplicationServiceApi(SimpleHttpClient): f"{service.url}{APP_SERVICE_PREFIX}/transactions/{urllib.parse.quote(str(txn_id))}", json_body=body, args=args, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + headers=self._get_headers(service), ) if logger.isEnabledFor(logging.DEBUG): logger.debug( @@ -437,7 +453,7 @@ class ApplicationServiceApi(SimpleHttpClient): response = await self.post_json_get_json( uri, body, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + headers=self._get_headers(service), ) except HttpResponseException as e: # The appservice doesn't support this endpoint. @@ -498,7 +514,7 @@ class ApplicationServiceApi(SimpleHttpClient): response = await self.post_json_get_json( uri, query, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + headers=self._get_headers(service), ) except HttpResponseException as e: # The appservice doesn't support this endpoint. diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 69a8318127..58856839e1 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py
@@ -179,8 +179,9 @@ class Config: If an integer is provided it is treated as bytes and is unchanged. - String byte sizes can have a suffix of 'K' or `M`, representing kibibytes and - mebibytes respectively. No suffix is understood as a plain byte count. + String byte sizes can have a suffix of 'K', `M`, `G` or `T`, + representing kibibytes, mebibytes, gibibytes and tebibytes respectively. + No suffix is understood as a plain byte count. Raises: TypeError, if given something other than an integer or a string @@ -189,7 +190,7 @@ class Config: if type(value) is int: # noqa: E721 return value elif isinstance(value, str): - sizes = {"K": 1024, "M": 1024 * 1024} + sizes = {"K": 1024, "M": 1024 * 1024, "G": 1024**3, "T": 1024**4} size = 1 suffix = value[-1] if suffix in sizes: diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 6e2d9addbf..bbc8f43073 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py
@@ -57,6 +57,8 @@ class CasConfig(Config): required_attributes ) + self.cas_enable_registration = cas_config.get("enable_registration", True) + self.idp_name = cas_config.get("idp_name", "CAS") self.idp_icon = cas_config.get("idp_icon") self.idp_brand = cas_config.get("idp_brand") @@ -67,6 +69,7 @@ class CasConfig(Config): self.cas_protocol_version = None self.cas_displayname_attribute = None self.cas_required_attributes = [] + self.cas_enable_registration = False # CAS uses a legacy required attributes mapping, not the one provided by diff --git a/synapse/config/oembed.py b/synapse/config/oembed.py
index d7959639ee..59bc0b55f4 100644 --- a/synapse/config/oembed.py +++ b/synapse/config/oembed.py
@@ -30,7 +30,7 @@ class OEmbedEndpointConfig: # The API endpoint to fetch. api_endpoint: str # The patterns to match. - url_patterns: List[Pattern] + url_patterns: List[Pattern[str]] # The supported formats. formats: Optional[List[str]] diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 260aab3241..fe86f54d80 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py
@@ -23,12 +23,7 @@ from signedjson.key import ( get_verify_key, is_signing_algorithm_supported, ) -from signedjson.sign import ( - SignatureVerifyException, - encode_canonical_json, - signature_ids, - verify_signed_json, -) +from signedjson.sign import SignatureVerifyException, signature_ids, verify_signed_json from signedjson.types import VerifyKey from unpaddedbase64 import decode_base64 @@ -596,24 +591,12 @@ class BaseV2KeyFetcher(KeyFetcher): verify_key=verify_key, valid_until_ts=key_data["expired_ts"] ) - key_json_bytes = encode_canonical_json(response_json) - - await make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self.store.store_server_keys_json, - server_name=server_name, - key_id=key_id, - from_server=from_server, - ts_now_ms=time_added_ms, - ts_expires_ms=ts_valid_until_ms, - key_json_bytes=key_json_bytes, - ) - for key_id in verify_keys - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) + await self.store.store_server_keys_response( + server_name=server_name, + from_server=from_server, + ts_added_ms=time_added_ms, + verify_keys=verify_keys, + response_json=response_json, ) return verify_keys @@ -775,10 +758,6 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): keys.setdefault(server_name, {}).update(processed_response) - await self.store.store_server_signature_keys( - perspective_name, time_now_ms, added_keys - ) - return keys def _validate_perspectives_response( diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index a9e3d4e556..5bdfa3a8ac 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py
@@ -55,7 +55,6 @@ class UnpersistedEventContextBase(ABC): A method to convert an UnpersistedEventContext to an EventContext, suitable for sending to the database with the associated event. """ - pass @abstractmethod async def get_prev_state_ids( @@ -69,7 +68,6 @@ class UnpersistedEventContextBase(ABC): state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules """ - pass @attr.s(slots=True, auto_attribs=True) diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py
index a850545453..b5b8b9bd35 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py
@@ -70,6 +70,7 @@ class CasHandler: self._cas_protocol_version = hs.config.cas.cas_protocol_version self._cas_displayname_attribute = hs.config.cas.cas_displayname_attribute self._cas_required_attributes = hs.config.cas.cas_required_attributes + self._cas_enable_registration = hs.config.cas.cas_enable_registration self._http_client = hs.get_proxied_http_client() @@ -395,4 +396,5 @@ class CasHandler: client_redirect_url, cas_response_to_user_attributes, grandfather_existing_users, + registration_enabled=self._cas_enable_registration, ) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 763f56dfc1..9d240ad4ee 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py
@@ -43,9 +43,12 @@ from synapse.metrics.background_process_metrics import ( ) from synapse.types import ( JsonDict, + JsonMapping, + ScheduledTask, StrCollection, StreamKeyType, StreamToken, + TaskStatus, UserID, get_domain_from_id, get_verify_key_from_cross_signing_key, @@ -55,13 +58,17 @@ from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.cancellation import cancellable from synapse.util.metrics import measure_func -from synapse.util.retryutils import NotRetryingDestination +from synapse.util.retryutils import ( + NotRetryingDestination, + filter_destinations_by_retry_limiter, +) if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) +DELETE_DEVICE_MSGS_TASK_NAME = "delete_device_messages" MAX_DEVICE_DISPLAY_NAME_LEN = 100 DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000 @@ -78,6 +85,7 @@ class DeviceWorkerHandler: self._appservice_handler = hs.get_application_service_handler() self._state_storage = hs.get_storage_controllers().state self._auth_handler = hs.get_auth_handler() + self._event_sources = hs.get_event_sources() self.server_name = hs.hostname self._msc3852_enabled = hs.config.experimental.msc3852_enabled self._query_appservices_for_keys = ( @@ -386,6 +394,7 @@ class DeviceHandler(DeviceWorkerHandler): self._account_data_handler = hs.get_account_data_handler() self._storage_controllers = hs.get_storage_controllers() self.db_pool = hs.get_datastores().main.db_pool + self._task_scheduler = hs.get_task_scheduler() self.device_list_updater = DeviceListUpdater(hs, self) @@ -419,6 +428,10 @@ class DeviceHandler(DeviceWorkerHandler): self._delete_stale_devices, ) + self._task_scheduler.register_action( + self._delete_device_messages, DELETE_DEVICE_MSGS_TASK_NAME + ) + def _check_device_name_length(self, name: Optional[str]) -> None: """ Checks whether a device name is longer than the maximum allowed length. @@ -530,6 +543,7 @@ class DeviceHandler(DeviceWorkerHandler): user_id: The user to delete devices from. device_ids: The list of device IDs to delete """ + to_device_stream_id = self._event_sources.get_current_token().to_device_key try: await self.store.delete_devices(user_id, device_ids) @@ -559,12 +573,49 @@ class DeviceHandler(DeviceWorkerHandler): f"org.matrix.msc3890.local_notification_settings.{device_id}", ) + # Delete device messages asynchronously and in batches using the task scheduler + await self._task_scheduler.schedule_task( + DELETE_DEVICE_MSGS_TASK_NAME, + resource_id=device_id, + params={ + "user_id": user_id, + "device_id": device_id, + "up_to_stream_id": to_device_stream_id, + }, + ) + # Pushers are deleted after `delete_access_tokens_for_user` is called so that # modules using `on_logged_out` hook can use them if needed. await self.hs.get_pusherpool().remove_pushers_by_devices(user_id, device_ids) await self.notify_device_update(user_id, device_ids) + DEVICE_MSGS_DELETE_BATCH_LIMIT = 100 + + async def _delete_device_messages( + self, + task: ScheduledTask, + ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + """Scheduler task to delete device messages in batch of `DEVICE_MSGS_DELETE_BATCH_LIMIT`.""" + assert task.params is not None + user_id = task.params["user_id"] + device_id = task.params["device_id"] + up_to_stream_id = task.params["up_to_stream_id"] + + res = await self.store.delete_messages_for_device( + user_id=user_id, + device_id=device_id, + up_to_stream_id=up_to_stream_id, + limit=DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT, + ) + + if res < DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT: + return TaskStatus.COMPLETE, None, None + else: + # There is probably still device messages to be deleted, let's keep the task active and it will be run + # again in a subsequent scheduler loop run (probably the next one, if not too many tasks are running). + return TaskStatus.ACTIVE, None, None + async def update_device(self, user_id: str, device_id: str, content: dict) -> None: """Update the given device @@ -982,7 +1033,7 @@ class DeviceListWorkerUpdater: async def multi_user_device_resync( self, user_ids: List[str], mark_failed_as_stale: bool = True - ) -> Dict[str, Optional[JsonDict]]: + ) -> Dict[str, Optional[JsonMapping]]: """ Like `user_device_resync` but operates on multiple users **from the same origin** at once. @@ -1011,6 +1062,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): self._notifier = hs.get_notifier() self._remote_edu_linearizer = Linearizer(name="remote_device_list") + self._resync_linearizer = Linearizer(name="remote_device_resync") # user_id -> list of updates waiting to be handled. self._pending_updates: Dict[ @@ -1220,8 +1272,18 @@ class DeviceListUpdater(DeviceListWorkerUpdater): self._resync_retry_in_progress = True # Get all of the users that need resyncing. need_resync = await self.store.get_user_ids_requiring_device_list_resync() + + # Filter out users whose host is marked as "down" up front. + hosts = await filter_destinations_by_retry_limiter( + {get_domain_from_id(u) for u in need_resync}, self.clock, self.store + ) + hosts = set(hosts) + # Iterate over the set of user IDs. for user_id in need_resync: + if get_domain_from_id(user_id) not in hosts: + continue + try: # Try to resync the current user's devices list. result = (await self.multi_user_device_resync([user_id], False))[ @@ -1253,7 +1315,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): async def multi_user_device_resync( self, user_ids: List[str], mark_failed_as_stale: bool = True - ) -> Dict[str, Optional[JsonDict]]: + ) -> Dict[str, Optional[JsonMapping]]: """ Like `user_device_resync` but operates on multiple users **from the same origin** at once. @@ -1273,9 +1335,11 @@ class DeviceListUpdater(DeviceListWorkerUpdater): failed = set() # TODO(Perf): Actually batch these up for user_id in user_ids: - user_result, user_failed = await self._user_device_resync_returning_failed( - user_id - ) + async with self._resync_linearizer.queue(user_id): + ( + user_result, + user_failed, + ) = await self._user_device_resync_returning_failed(user_id) result[user_id] = user_result if user_failed: failed.add(user_id) @@ -1287,7 +1351,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): async def _user_device_resync_returning_failed( self, user_id: str - ) -> Tuple[Optional[JsonDict], bool]: + ) -> Tuple[Optional[JsonMapping], bool]: """Fetches all devices for a user and updates the device cache with them. Args: @@ -1300,6 +1364,12 @@ class DeviceListUpdater(DeviceListWorkerUpdater): e.g. due to a connection problem. - True iff the resync failed and the device list should be marked as stale. """ + # Check that we haven't gone and fetched the devices since we last + # checked if we needed to resync these device lists. + if await self.store.get_users_whose_devices_are_cached([user_id]): + cached = await self.store.get_cached_devices_for_user(user_id) + return cached, False + logger.debug("Attempting to resync the device list for %s", user_id) log_kv({"message": "Doing resync to update device list."}) # Fetch all devices for the user. diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index b3be7a86f0..5dc76ef588 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py
@@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.api.constants import ( AccountDataTypes, @@ -23,7 +23,6 @@ from synapse.api.constants import ( Membership, ) from synapse.api.errors import SynapseError -from synapse.events import EventBase from synapse.events.utils import SerializeEventConfig from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state @@ -35,7 +34,6 @@ from synapse.types import ( JsonDict, Requester, RoomStreamToken, - StateMap, StreamKeyType, StreamToken, UserID, @@ -199,9 +197,7 @@ class InitialSyncHandler: deferred_room_state = run_in_background( self._state_storage_controller.get_state_for_events, [event.event_id], - ).addCallback( - lambda states: cast(StateMap[EventBase], states[event.event_id]) - ) + ).addCallback(lambda states: states[event.event_id]) (messages, token), current_state = await make_deferred_yieldable( gather_results( diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index e5ac9096cc..19cf5a2b43 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py
@@ -713,7 +713,7 @@ class PaginationHandler: self, delete_id: str, room_id: str, - requester_user_id: str, + requester_user_id: Optional[str], new_room_user_id: Optional[str] = None, new_room_name: Optional[str] = None, message: Optional[str] = None, @@ -732,6 +732,10 @@ class PaginationHandler: requester_user_id: User who requested the action. Will be recorded as putting the room on the blocking list. + If None, the action was not manually requested but instead + triggered automatically, e.g. through a Synapse module + or some other policy. + MUST NOT be None if block=True. new_room_user_id: If set, a new room will be created with this user ID as the creator and admin, and all users in the old room will be @@ -818,7 +822,7 @@ class PaginationHandler: def start_shutdown_and_purge_room( self, room_id: str, - requester_user_id: str, + requester_user_id: Optional[str], new_room_user_id: Optional[str] = None, new_room_name: Optional[str] = None, message: Optional[str] = None, @@ -833,6 +837,10 @@ class PaginationHandler: requester_user_id: User who requested the action and put the room on the blocking list. + If None, the action was not manually requested but instead + triggered automatically, e.g. through a Synapse module + or some other policy. + MUST NOT be None if block=True. new_room_user_id: If set, a new room will be created with this user ID as the creator and admin, and all users in the old room will be diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index f31e18328b..375c7d0901 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py
@@ -13,13 +13,56 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""This module is responsible for keeping track of presence status of local +""" +This module is responsible for keeping track of presence status of local and remote users. The methods that define policy are: - PresenceHandler._update_states - PresenceHandler._handle_timeouts - should_notify + +# Tracking local presence + +For local users, presence is tracked on a per-device basis. When a user has multiple +devices the user presence state is derived by coalescing the presence from each +device: + + BUSY > ONLINE > UNAVAILABLE > OFFLINE + +The time that each device was last active and last synced is tracked in order to +automatically downgrade a device's presence state: + + A device may move from ONLINE -> UNAVAILABLE, if it has not been active for + a period of time. + + A device may go from any state -> OFFLINE, if it is not active and has not + synced for a period of time. + +The timeouts are handled using a wheel timer, which has coarse buckets. Timings +do not need to be exact. + +Generally a device's presence state is updated whenever a user syncs (via the +set_presence parameter), when the presence API is called, or if "pro-active" +events occur, including: + +* Sending an event, receipt, read marker. +* Updating typing status. + +The busy state has special status that it cannot is not downgraded by a call to +sync with a lower priority state *and* it takes a long period of time to transition +to offline. + +# Persisting (and restoring) presence + +For all users, presence is persisted on a per-user basis. Data is kept in-memory +and persisted periodically. When Synapse starts each worker loads the current +presence state and then tracks the presence stream to keep itself up-to-date. + +When restoring presence for local users a pseudo-device is created to match the +user state; this device follows the normal timeout logic (see above) and will +automatically be replaced with any information from currently available devices. + """ import abc import contextlib @@ -30,6 +73,7 @@ from contextlib import contextmanager from types import TracebackType from typing import ( TYPE_CHECKING, + AbstractSet, Any, Callable, Collection, @@ -49,7 +93,7 @@ from prometheus_client import Counter import synapse.metrics from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState from synapse.api.errors import SynapseError -from synapse.api.presence import UserPresenceState +from synapse.api.presence import UserDevicePresenceState, UserPresenceState from synapse.appservice import ApplicationService from synapse.events.presence_router import PresenceRouter from synapse.logging.context import run_in_background @@ -111,6 +155,8 @@ LAST_ACTIVE_GRANULARITY = 60 * 1000 # How long to wait until a new /events or /sync request before assuming # the client has gone. SYNC_ONLINE_TIMEOUT = 30 * 1000 +# Busy status waits longer, but does eventually go offline. +BUSY_ONLINE_TIMEOUT = 60 * 60 * 1000 # How long to wait before marking the user as idle. Compared against last active IDLE_TIMER = 5 * 60 * 1000 @@ -137,6 +183,7 @@ class BasePresenceHandler(abc.ABC): writer""" def __init__(self, hs: "HomeServer"): + self.hs = hs self.clock = hs.get_clock() self.store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() @@ -162,6 +209,7 @@ class BasePresenceHandler(abc.ABC): self.VALID_PRESENCE += (PresenceState.BUSY,) active_presence = self.store.take_presence_startup_info() + # The combined status across all user devices. self.user_to_current_state = {state.user_id: state for state in active_presence} @abc.abstractmethod @@ -426,8 +474,6 @@ class _NullContextManager(ContextManager[None]): class WorkerPresenceHandler(BasePresenceHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs - self._presence_writer_instance = hs.config.worker.writers.presence[0] # Route presence EDUs to the right worker @@ -691,7 +737,6 @@ class WorkerPresenceHandler(BasePresenceHandler): class PresenceHandler(BasePresenceHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.wheel_timer: WheelTimer[str] = WheelTimer() self.notifier = hs.get_notifier() @@ -708,9 +753,27 @@ class PresenceHandler(BasePresenceHandler): lambda: len(self.user_to_current_state), ) + # The per-device presence state, maps user to devices to per-device presence state. + self._user_to_device_to_current_state: Dict[ + str, Dict[Optional[str], UserDevicePresenceState] + ] = {} + now = self.clock.time_msec() if self._presence_enabled: for state in self.user_to_current_state.values(): + # Create a psuedo-device to properly handle time outs. This will + # be overridden by any "real" devices within SYNC_ONLINE_TIMEOUT. + pseudo_device_id = None + self._user_to_device_to_current_state[state.user_id] = { + pseudo_device_id: UserDevicePresenceState( + user_id=state.user_id, + device_id=pseudo_device_id, + state=state.state, + last_active_ts=state.last_active_ts, + last_sync_ts=state.last_user_sync_ts, + ) + } + self.wheel_timer.insert( now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER ) @@ -752,7 +815,7 @@ class PresenceHandler(BasePresenceHandler): # Keeps track of the number of *ongoing* syncs on other processes. # - # While any sync is ongoing on another process the user will never + # While any sync is ongoing on another process the user's device will never # go offline. # # Each process has a unique identifier and an update frequency. If @@ -981,22 +1044,21 @@ class PresenceHandler(BasePresenceHandler): timers_fired_counter.inc(len(states)) - syncing_user_ids = { - user_id - for (user_id, _), count in self._user_device_to_num_current_syncs.items() + # Set of user ID & device IDs which are currently syncing. + syncing_user_devices = { + user_id_device_id + for user_id_device_id, count in self._user_device_to_num_current_syncs.items() if count } - syncing_user_ids.update( - user_id - for user_id, _ in itertools.chain( - *self.external_process_to_current_syncs.values() - ) + syncing_user_devices.update( + itertools.chain(*self.external_process_to_current_syncs.values()) ) changes = handle_timeouts( states, is_mine_fn=self.is_mine_id, - syncing_user_ids=syncing_user_ids, + syncing_user_devices=syncing_user_devices, + user_to_devices=self._user_to_device_to_current_state, now=now, ) @@ -1016,11 +1078,26 @@ class PresenceHandler(BasePresenceHandler): bump_active_time_counter.inc() - prev_state = await self.current_state_for_user(user_id) + now = self.clock.time_msec() - new_fields: Dict[str, Any] = {"last_active_ts": self.clock.time_msec()} - if prev_state.state == PresenceState.UNAVAILABLE: - new_fields["state"] = PresenceState.ONLINE + # Update the device information & mark the device as online if it was + # unavailable. + devices = self._user_to_device_to_current_state.setdefault(user_id, {}) + device_state = devices.setdefault( + device_id, + UserDevicePresenceState.default(user_id, device_id), + ) + device_state.last_active_ts = now + if device_state.state == PresenceState.UNAVAILABLE: + device_state.state = PresenceState.ONLINE + + # Update the user state, this will always update last_active_ts and + # might update the presence state. + prev_state = await self.current_state_for_user(user_id) + new_fields: Dict[str, Any] = { + "last_active_ts": now, + "state": _combine_device_states(devices.values()), + } await self._update_states([prev_state.copy_and_replace(**new_fields)]) @@ -1132,6 +1209,12 @@ class PresenceHandler(BasePresenceHandler): if is_syncing and (user_id, device_id) not in process_presence: process_presence.add((user_id, device_id)) elif not is_syncing and (user_id, device_id) in process_presence: + devices = self._user_to_device_to_current_state.setdefault(user_id, {}) + device_state = devices.setdefault( + device_id, UserDevicePresenceState.default(user_id, device_id) + ) + device_state.last_sync_ts = sync_time_msec + new_state = prev_state.copy_and_replace( last_user_sync_ts=sync_time_msec ) @@ -1151,11 +1234,24 @@ class PresenceHandler(BasePresenceHandler): process_presence = self.external_process_to_current_syncs.pop( process_id, set() ) - prev_states = await self.current_state_for_users( - {user_id for user_id, device_id in process_presence} - ) + time_now_ms = self.clock.time_msec() + # Mark each device as having a last sync time. + updated_users = set() + for user_id, device_id in process_presence: + device_state = self._user_to_device_to_current_state.setdefault( + user_id, {} + ).setdefault( + device_id, UserDevicePresenceState.default(user_id, device_id) + ) + + device_state.last_sync_ts = time_now_ms + updated_users.add(user_id) + + # Update each user (and insert into the appropriate timers to check if + # they've gone offline). + prev_states = await self.current_state_for_users(updated_users) await self._update_states( [ prev_state.copy_and_replace(last_user_sync_ts=time_now_ms) @@ -1277,6 +1373,20 @@ class PresenceHandler(BasePresenceHandler): if prev_state.state == PresenceState.BUSY and is_sync: presence = PresenceState.BUSY + # Update the device specific information. + devices = self._user_to_device_to_current_state.setdefault(user_id, {}) + device_state = devices.setdefault( + device_id, + UserDevicePresenceState.default(user_id, device_id), + ) + device_state.state = presence + device_state.last_active_ts = now + if is_sync: + device_state.last_sync_ts = now + + # Based on the state of each user's device calculate the new presence state. + presence = _combine_device_states(devices.values()) + new_fields = {"state": presence} if presence == PresenceState.ONLINE or presence == PresenceState.BUSY: @@ -1873,7 +1983,8 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): def handle_timeouts( user_states: List[UserPresenceState], is_mine_fn: Callable[[str], bool], - syncing_user_ids: Set[str], + syncing_user_devices: AbstractSet[Tuple[str, Optional[str]]], + user_to_devices: Dict[str, Dict[Optional[str], UserDevicePresenceState]], now: int, ) -> List[UserPresenceState]: """Checks the presence of users that have timed out and updates as @@ -1882,7 +1993,8 @@ def handle_timeouts( Args: user_states: List of UserPresenceState's to check. is_mine_fn: Function that returns if a user_id is ours - syncing_user_ids: Set of user_ids with active syncs. + syncing_user_devices: A set of (user ID, device ID) tuples with active syncs.. + user_to_devices: A map of user ID to device ID to UserDevicePresenceState. now: Current time in ms. Returns: @@ -1891,9 +2003,16 @@ def handle_timeouts( changes = {} # Actual changes we need to notify people about for state in user_states: - is_mine = is_mine_fn(state.user_id) - - new_state = handle_timeout(state, is_mine, syncing_user_ids, now) + user_id = state.user_id + is_mine = is_mine_fn(user_id) + + new_state = handle_timeout( + state, + is_mine, + syncing_user_devices, + user_to_devices.get(user_id, {}), + now, + ) if new_state: changes[state.user_id] = new_state @@ -1901,14 +2020,19 @@ def handle_timeouts( def handle_timeout( - state: UserPresenceState, is_mine: bool, syncing_user_ids: Set[str], now: int + state: UserPresenceState, + is_mine: bool, + syncing_device_ids: AbstractSet[Tuple[str, Optional[str]]], + user_devices: Dict[Optional[str], UserDevicePresenceState], + now: int, ) -> Optional[UserPresenceState]: """Checks the presence of the user to see if any of the timers have elapsed Args: - state + state: UserPresenceState to check. is_mine: Whether the user is ours - syncing_user_ids: Set of user_ids with active syncs. + syncing_user_devices: A set of (user ID, device ID) tuples with active syncs.. + user_devices: A map of device ID to UserDevicePresenceState. now: Current time in ms. Returns: @@ -1919,34 +2043,63 @@ def handle_timeout( return None changed = False - user_id = state.user_id if is_mine: - if state.state == PresenceState.ONLINE: - if now - state.last_active_ts > IDLE_TIMER: - # Currently online, but last activity ages ago so auto - # idle - state = state.copy_and_replace(state=PresenceState.UNAVAILABLE) - changed = True - elif now - state.last_active_ts > LAST_ACTIVE_GRANULARITY: - # So that we send down a notification that we've - # stopped updating. + # Check per-device whether the device should be considered idle or offline + # due to timeouts. + device_changed = False + offline_devices = [] + for device_id, device_state in user_devices.items(): + if device_state.state == PresenceState.ONLINE: + if now - device_state.last_active_ts > IDLE_TIMER: + # Currently online, but last activity ages ago so auto + # idle + device_state.state = PresenceState.UNAVAILABLE + device_changed = True + + # If there are have been no sync for a while (and none ongoing), + # set presence to offline. + if (state.user_id, device_id) not in syncing_device_ids: + # If the user has done something recently but hasn't synced, + # don't set them as offline. + sync_or_active = max( + device_state.last_sync_ts, device_state.last_active_ts + ) + + # Implementations aren't meant to timeout a device with a busy + # state, but it needs to timeout *eventually* or else the user + # will be stuck in that state. + online_timeout = ( + BUSY_ONLINE_TIMEOUT + if device_state.state == PresenceState.BUSY + else SYNC_ONLINE_TIMEOUT + ) + if now - sync_or_active > online_timeout: + # Mark the device as going offline. + offline_devices.append(device_id) + device_changed = True + + # Offline devices are not needed and do not add information. + for device_id in offline_devices: + user_devices.pop(device_id) + + # If the presence state of the devices changed, then (maybe) update + # the user's overall presence state. + if device_changed: + new_presence = _combine_device_states(user_devices.values()) + if new_presence != state.state: + state = state.copy_and_replace(state=new_presence) changed = True + if now - state.last_active_ts > LAST_ACTIVE_GRANULARITY: + # So that we send down a notification that we've + # stopped updating. + changed = True + if now - state.last_federation_update_ts > FEDERATION_PING_INTERVAL: # Need to send ping to other servers to ensure they don't # timeout and set us to offline changed = True - - # If there are have been no sync for a while (and none ongoing), - # set presence to offline - if user_id not in syncing_user_ids: - # If the user has done something recently but hasn't synced, - # don't set them as offline. - sync_or_active = max(state.last_user_sync_ts, state.last_active_ts) - if now - sync_or_active > SYNC_ONLINE_TIMEOUT: - state = state.copy_and_replace(state=PresenceState.OFFLINE) - changed = True else: # We expect to be poked occasionally by the other side. # This is to protect against forgetful/buggy servers, so that @@ -2021,6 +2174,13 @@ def handle_update( new_state = new_state.copy_and_replace(last_federation_update_ts=now) federation_ping = True + if new_state.state == PresenceState.BUSY: + wheel_timer.insert( + now=now, + obj=user_id, + then=new_state.last_user_sync_ts + BUSY_ONLINE_TIMEOUT, + ) + else: wheel_timer.insert( now=now, @@ -2036,6 +2196,46 @@ def handle_update( return new_state, persist_and_notify, federation_ping +PRESENCE_BY_PRIORITY = { + PresenceState.BUSY: 4, + PresenceState.ONLINE: 3, + PresenceState.UNAVAILABLE: 2, + PresenceState.OFFLINE: 1, +} + + +def _combine_device_states( + device_states: Iterable[UserDevicePresenceState], +) -> str: + """ + Find the device to use presence information from. + + Orders devices by priority, then last_active_ts. + + Args: + device_states: An iterable of device presence states + + Return: + The combined presence state. + """ + + # Based on (all) the user's devices calculate the new presence state. + presence = PresenceState.OFFLINE + last_active_ts = -1 + + # Find the device to use the presence state of based on the presence priority, + # but tie-break with how recently the device has been seen. + for device_state in device_states: + if (PRESENCE_BY_PRIORITY[device_state.state], device_state.last_active_ts) > ( + PRESENCE_BY_PRIORITY[presence], + last_active_ts, + ): + presence = device_state.state + last_active_ts = device_state.last_active_ts + + return presence + + async def get_interested_parties( store: DataStore, presence_router: PresenceRouter, states: List[UserPresenceState] ) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 0513e28aab..7a762c8511 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py
@@ -1787,7 +1787,7 @@ class RoomShutdownHandler: async def shutdown_room( self, room_id: str, - requester_user_id: str, + requester_user_id: Optional[str], new_room_user_id: Optional[str] = None, new_room_name: Optional[str] = None, message: Optional[str] = None, @@ -1811,6 +1811,10 @@ class RoomShutdownHandler: requester_user_id: User who requested the action and put the room on the blocking list. + If None, the action was not manually requested but instead + triggered automatically, e.g. through a Synapse module + or some other policy. + MUST NOT be None if block=True. new_room_user_id: If set, a new room will be created with this user ID as the creator and admin, and all users in the old room will be @@ -1863,6 +1867,10 @@ class RoomShutdownHandler: # Action the block first (even if the room doesn't exist yet) if block: + if requester_user_id is None: + raise ValueError( + "shutdown_room: block=True not allowed when requester_user_id is None." + ) # This will work even if the room is already blocked, but that is # desirable in case the first attempt at blocking the room failed below. await self.store.block_room(room_id, requester_user_id) diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py
index 05e21509de..4f5fe62fe8 100644 --- a/synapse/handlers/send_email.py +++ b/synapse/handlers/send_email.py
@@ -17,7 +17,7 @@ import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from io import BytesIO -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional from pkg_resources import parse_version @@ -151,6 +151,7 @@ class SendEmailHandler: app_name: str, html: str, text: str, + additional_headers: Optional[Dict[str, str]] = None, ) -> None: """Send a multipart email with the given information. @@ -160,6 +161,7 @@ class SendEmailHandler: app_name: The app name to include in the From header. html: The HTML content to include in the email. text: The plain text content to include in the email. + additional_headers: A map of additional headers to include. """ try: from_string = self._from % {"app": app_name} @@ -181,6 +183,7 @@ class SendEmailHandler: multipart_msg["To"] = email_address multipart_msg["Date"] = email.utils.formatdate() multipart_msg["Message-ID"] = email.utils.make_msgid() + # Discourage automatic responses to Synapse's emails. # Per RFC 3834, automatic responses should not be sent if the "Auto-Submitted" # header is present with any value other than "no". See @@ -194,6 +197,11 @@ class SendEmailHandler: # https://stackoverflow.com/a/25324691/5252017 # https://stackoverflow.com/a/61646381/5252017 multipart_msg["X-Auto-Response-Suppress"] = "All" + + if additional_headers: + for header, value in additional_headers.items(): + multipart_msg[header] = value + multipart_msg.attach(text_part) multipart_msg.attach(html_part) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 60a9f341b5..0ccd7d250c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
@@ -40,6 +40,7 @@ from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase +from synapse.handlers.device import DELETE_DEVICE_MSGS_TASK_NAME from synapse.handlers.relations import BundledAggregations from synapse.logging import issue9533_logger from synapse.logging.context import current_context @@ -268,6 +269,7 @@ class SyncHandler: self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state self._device_handler = hs.get_device_handler() + self._task_scheduler = hs.get_task_scheduler() self.should_calculate_push_rules = hs.config.push.enable_push @@ -360,11 +362,19 @@ class SyncHandler: # (since we now know that the device has received them) if since_token is not None: since_stream_id = since_token.to_device_key - deleted = await self.store.delete_messages_for_device( - sync_config.user.to_string(), sync_config.device_id, since_stream_id + # Delete device messages asynchronously and in batches using the task scheduler + await self._task_scheduler.schedule_task( + DELETE_DEVICE_MSGS_TASK_NAME, + resource_id=sync_config.device_id, + params={ + "user_id": sync_config.user.to_string(), + "device_id": sync_config.device_id, + "up_to_stream_id": since_stream_id, + }, ) logger.debug( - "Deleted %d to-device messages up to %d", deleted, since_stream_id + "Deletion of to-device messages up to %d scheduled", + since_stream_id, ) if timeout == 0 or since_token is None or full_state: diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 91a24efcd0..a3a396bb37 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py
@@ -399,15 +399,34 @@ class MatrixHostnameEndpoint: if port or _is_ip_literal(host): return [Server(host, port or 8448)] + # Check _matrix-fed._tcp SRV record. logger.debug("Looking up SRV record for %s", host.decode(errors="replace")) + server_list = await self._srv_resolver.resolve_service( + b"_matrix-fed._tcp." + host + ) + + if server_list: + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Got %s from SRV lookup for %s", + ", ".join(map(str, server_list)), + host.decode(errors="replace"), + ) + return server_list + + # No _matrix-fed._tcp SRV record, fallback to legacy _matrix._tcp SRV record. + logger.debug( + "Looking up deprecated SRV record for %s", host.decode(errors="replace") + ) server_list = await self._srv_resolver.resolve_service(b"_matrix._tcp." + host) if server_list: - logger.debug( - "Got %s from SRV lookup for %s", - ", ".join(map(str, server_list)), - host.decode(errors="replace"), - ) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Got %s from deprecated SRV lookup for %s", + ", ".join(map(str, server_list)), + host.decode(errors="replace"), + ) return server_list # No SRV records, so we fallback to host and 8448 diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 64c6ae4512..bf7e311026 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py
@@ -728,7 +728,7 @@ async def _unwrap_awaitable(awaitable: Awaitable[R]) -> R: @overload -def preserve_fn( # type: ignore[misc] +def preserve_fn( f: Callable[P, Awaitable[R]], ) -> Callable[P, "defer.Deferred[R]"]: # The `type: ignore[misc]` above suppresses @@ -756,7 +756,7 @@ def preserve_fn( @overload -def run_in_background( # type: ignore[misc] +def run_in_background( f: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs ) -> "defer.Deferred[R]": # The `type: ignore[misc]` above suppresses diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 5c3045e197..4454fe29a5 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py
@@ -991,11 +991,7 @@ def trace_with_opname( if not opentracing: return func - # type-ignore: mypy seems to be confused by the ParamSpecs here. - # I think the problem is https://github.com/python/mypy/issues/12909 - return _custom_sync_async_decorator( - func, _wrapping_logic # type: ignore[arg-type] - ) + return _custom_sync_async_decorator(func, _wrapping_logic) return _decorator @@ -1040,9 +1036,7 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]: set_tag(SynapseTags.FUNC_KWARGS, str(kwargs)) yield - # type-ignore: mypy seems to be confused by the ParamSpecs here. - # I think the problem is https://github.com/python/mypy/issues/12909 - return _custom_sync_async_decorator(func, _wrapping_logic) # type: ignore[arg-type] + return _custom_sync_async_decorator(func, _wrapping_logic) @contextlib.contextmanager diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py
index 70b32cee17..9b5a3dd5f4 100644 --- a/synapse/media/url_previewer.py +++ b/synapse/media/url_previewer.py
@@ -846,9 +846,7 @@ def _is_media(content_type: str) -> bool: def _is_html(content_type: str) -> bool: content_type = content_type.lower() - return content_type.startswith("text/html") or content_type.startswith( - "application/xhtml" - ) + return content_type.startswith(("text/html", "application/xhtml")) def _is_json(content_type: str) -> bool: diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 2f00a7ba20..d6efe10a28 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py
@@ -1730,6 +1730,19 @@ class ModuleApi: room_alias_str = room_alias.to_string() if room_alias else None return room_id, room_alias_str + async def delete_room(self, room_id: str) -> None: + """ + Schedules the deletion of a room from Synapse's database. + + If the room is already being deleted, this method does nothing. + This method does not wait for the room to be deleted. + + Added in Synapse v1.89.0. + """ + # Future extensions to this method might want to e.g. allow use of `force_purge`. + # TODO In the future we should make sure this is persistent. + self._hs.get_pagination_handler().start_shutdown_and_purge_room(room_id, None) + async def set_displayname( self, user_id: UserID, diff --git a/synapse/module_api/callbacks/third_party_event_rules_callbacks.py b/synapse/module_api/callbacks/third_party_event_rules_callbacks.py
index 911f37ba42..ecaeef3511 100644 --- a/synapse/module_api/callbacks/third_party_event_rules_callbacks.py +++ b/synapse/module_api/callbacks/third_party_event_rules_callbacks.py
@@ -40,7 +40,7 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[ [str, StateMap[EventBase], str], Awaitable[bool] ] ON_NEW_EVENT_CALLBACK = Callable[[EventBase, StateMap[EventBase]], Awaitable] -CHECK_CAN_SHUTDOWN_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]] +CHECK_CAN_SHUTDOWN_ROOM_CALLBACK = Callable[[Optional[str], str], Awaitable[bool]] CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]] ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable] ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable] @@ -429,12 +429,17 @@ class ThirdPartyEventRulesModuleApiCallbacks: "Failed to run module API callback %s: %s", callback, e ) - async def check_can_shutdown_room(self, user_id: str, room_id: str) -> bool: + async def check_can_shutdown_room( + self, user_id: Optional[str], room_id: str + ) -> bool: """Intercept requests to shutdown a room. If `False` is returned, the room must not be shut down. Args: - requester: The ID of the user requesting the shutdown. + user_id: The ID of the user requesting the shutdown. + If no user ID is supplied, then the room is being shut down through + some mechanism other than a user's request, e.g. through a module's + request. room_id: The ID of the room. """ for callback in self._check_can_shutdown_room_callbacks: diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 79e0627b6a..b6cad18c2d 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py
@@ -298,20 +298,26 @@ class Mailer: notifs_by_room, state_by_room, notif_events, reason ) + unsubscribe_link = self._make_unsubscribe_link(user_id, app_id, email_address) + template_vars: TemplateVars = { "user_display_name": user_display_name, - "unsubscribe_link": self._make_unsubscribe_link( - user_id, app_id, email_address - ), + "unsubscribe_link": unsubscribe_link, "summary_text": summary_text, "rooms": rooms, "reason": reason, } - await self.send_email(email_address, summary_text, template_vars) + await self.send_email( + email_address, summary_text, template_vars, unsubscribe_link + ) async def send_email( - self, email_address: str, subject: str, extra_template_vars: TemplateVars + self, + email_address: str, + subject: str, + extra_template_vars: TemplateVars, + unsubscribe_link: Optional[str] = None, ) -> None: """Send an email with the given information and template text""" template_vars: TemplateVars = { @@ -330,6 +336,23 @@ class Mailer: app_name=self.app_name, html=html_text, text=plain_text, + # Include the List-Unsubscribe header which some clients render in the UI. + # Per RFC 2369, this can be a URL or mailto URL. See + # https://www.rfc-editor.org/rfc/rfc2369.html#section-3.2 + # + # It is preferred to use email, but Synapse doesn't support incoming email. + # + # Also include the List-Unsubscribe-Post header from RFC 8058. See + # https://www.rfc-editor.org/rfc/rfc8058.html#section-3.1 + # + # Note that many email clients will not render the unsubscribe link + # unless DKIM, etc. is properly setup. + additional_headers={ + "List-Unsubscribe-Post": "List-Unsubscribe=One-Click", + "List-Unsubscribe": f"<{unsubscribe_link}>", + } + if unsubscribe_link + else None, ) async def _get_room_vars( diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index 209833d287..b8198e059c 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py
@@ -20,7 +20,7 @@ from twisted.web.server import Request from synapse.http.server import HttpServer from synapse.logging.opentracing import active_span from synapse.replication.http._base import ReplicationEndpoint -from synapse.types import JsonDict +from synapse.types import JsonDict, JsonMapping if TYPE_CHECKING: from synapse.server import HomeServer @@ -82,7 +82,7 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint): async def _handle_request( # type: ignore[override] self, request: Request, content: JsonDict - ) -> Tuple[int, Dict[str, Optional[JsonDict]]]: + ) -> Tuple[int, Dict[str, Optional[JsonMapping]]]: user_ids: List[str] = content["user_ids"] logger.info("Resync for %r", user_ids) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index d9045d7b73..5642666411 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py
@@ -644,7 +644,7 @@ class ReplicationCommandHandler: [stream.parse_row(row) for row in rows], ) - logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token) + logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token) # We've now caught up to position sent to us, notify handler. await self._replication_data_handler.on_position( diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 347467d863..1d9a29d22e 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py
@@ -191,7 +191,12 @@ class ReplicationStreamer: if updates: logger.info( - "Streaming: %s -> %s", stream.NAME, updates[-1][0] + "Streaming: %s -> %s (limited: %s, updates: %s, max token: %s)", + stream.NAME, + updates[-1][0], + limited, + len(updates), + current_token, ) stream_updates_counter.labels(stream.NAME).inc(len(updates)) diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index df0845edb2..1be9c47c61 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py
@@ -123,7 +123,7 @@ class ClientRestResource(JsonResource): if is_main_process: report_event.register_servlets(hs, client_resource) openid.register_servlets(hs, client_resource) - notifications.register_servlets(hs, client_resource) + notifications.register_servlets(hs, client_resource) devices.register_servlets(hs, client_resource) if is_main_process: thirdparty.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 679ab9f266..49cd0805fd 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py
@@ -179,85 +179,81 @@ class PasswordRestServlet(RestServlet): # # In the second case, we require a password to confirm their identity. - requester = None - if self.auth.has_access_token(request): - requester = await self.auth.get_user_by_req(request) - try: + try: + requester = None + if self.auth.has_access_token(request): + requester = await self.auth.get_user_by_req(request) params, session_id = await self.auth_handler.validate_user_via_ui_auth( requester, request, - body.dict(exclude_unset=True), + body.dict(exclude_unset=True, exclude={"new_password"}), "modify your account password", ) - except InteractiveAuthIncompleteError as e: - # The user needs to provide more steps to complete auth, but - # they're not required to provide the password again. - # - # If a password is available now, hash the provided password and - # store it for later. - if new_password: - new_password_hash = await self.auth_handler.hash(new_password) - await self.auth_handler.set_session_data( - e.session_id, - UIAuthSessionDataConstants.PASSWORD_HASH, - new_password_hash, - ) - raise - user_id = requester.user.to_string() - else: - try: + user_id = requester.user.to_string() + else: result, params, session_id = await self.auth_handler.check_ui_auth( [[LoginType.EMAIL_IDENTITY]], request, - body.dict(exclude_unset=True), + body.dict(exclude_unset=True, exclude={"new_password"}), "modify your account password", ) - except InteractiveAuthIncompleteError as e: - # The user needs to provide more steps to complete auth, but - # they're not required to provide the password again. - # - # If a password is available now, hash the provided password and - # store it for later. - if new_password: - new_password_hash = await self.auth_handler.hash(new_password) - await self.auth_handler.set_session_data( - e.session_id, - UIAuthSessionDataConstants.PASSWORD_HASH, - new_password_hash, + + if LoginType.EMAIL_IDENTITY in result: + threepid = result[LoginType.EMAIL_IDENTITY] + if "medium" not in threepid or "address" not in threepid: + raise SynapseError(500, "Malformed threepid") + if threepid["medium"] == "email": + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. + # (See add_threepid in synapse/handlers/auth.py) + try: + threepid["address"] = validate_email(threepid["address"]) + except ValueError as e: + raise SynapseError(400, str(e)) + # if using email, we must know about the email they're authing with! + threepid_user_id = await self.datastore.get_user_id_by_threepid( + threepid["medium"], threepid["address"] ) + if not threepid_user_id: + raise SynapseError( + 404, "Email address not found", Codes.NOT_FOUND + ) + user_id = threepid_user_id + else: + logger.error("Auth succeeded but no known type! %r", result.keys()) + raise SynapseError(500, "", Codes.UNKNOWN) + + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth, but + # they're not required to provide the password again. + # + # If a password is available now, hash the provided password and + # store it for later. We only do this if we don't already have the + # password hash stored, to avoid repeatedly hashing the password. + + if not new_password: raise - if LoginType.EMAIL_IDENTITY in result: - threepid = result[LoginType.EMAIL_IDENTITY] - if "medium" not in threepid or "address" not in threepid: - raise SynapseError(500, "Malformed threepid") - if threepid["medium"] == "email": - # For emails, canonicalise the address. - # We store all email addresses canonicalised in the DB. - # (See add_threepid in synapse/handlers/auth.py) - try: - threepid["address"] = validate_email(threepid["address"]) - except ValueError as e: - raise SynapseError(400, str(e)) - # if using email, we must know about the email they're authing with! - threepid_user_id = await self.datastore.get_user_id_by_threepid( - threepid["medium"], threepid["address"] - ) - if not threepid_user_id: - raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) - user_id = threepid_user_id - else: - logger.error("Auth succeeded but no known type! %r", result.keys()) - raise SynapseError(500, "", Codes.UNKNOWN) + existing_session_password_hash = await self.auth_handler.get_session_data( + e.session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None + ) + if existing_session_password_hash: + raise + + new_password_hash = await self.auth_handler.hash(new_password) + await self.auth_handler.set_session_data( + e.session_id, + UIAuthSessionDataConstants.PASSWORD_HASH, + new_password_hash, + ) + raise # If we have a password in this request, prefer it. Otherwise, use the # password hash from an earlier request. if new_password: password_hash: Optional[str] = await self.auth_handler.hash(new_password) elif session_id is not None: - password_hash = await self.auth_handler.get_session_data( - session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None - ) + password_hash = existing_session_password_hash else: # UI validation was skipped, but the request did not include a new # password. diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index ea10042569..e7fe1332e7 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py
@@ -36,6 +36,8 @@ logger = logging.getLogger(__name__) class NotificationsServlet(RestServlet): PATTERNS = client_patterns("/notifications$") + CATEGORY = "Client API requests" + def __init__(self, hs: "HomeServer"): super().__init__() self.store = hs.get_datastores().main diff --git a/synapse/rest/synapse/client/unsubscribe.py b/synapse/rest/synapse/client/unsubscribe.py
index 60321018f9..050fd7bba1 100644 --- a/synapse/rest/synapse/client/unsubscribe.py +++ b/synapse/rest/synapse/client/unsubscribe.py
@@ -38,6 +38,10 @@ class UnsubscribeResource(DirectServeHtmlResource): self.macaroon_generator = hs.get_macaroon_generator() async def _async_render_GET(self, request: SynapseRequest) -> None: + """ + Handle a user opening an unsubscribe link in the browser, either via an + HTML/Text email or via the List-Unsubscribe header. + """ token = parse_string(request, "access_token", required=True) app_id = parse_string(request, "app_id", required=True) pushkey = parse_string(request, "pushkey", required=True) @@ -62,3 +66,16 @@ class UnsubscribeResource(DirectServeHtmlResource): 200, UnsubscribeResource.SUCCESS_HTML, ) + + async def _async_render_POST(self, request: SynapseRequest) -> None: + """ + Handle a mail user agent POSTing to the unsubscribe URL via the + List-Unsubscribe & List-Unsubscribe-Post headers. + """ + + # TODO Assert that the body has a single field + + # Assert the body has form encoded key/value pair of + # List-Unsubscribe=One-Click. + + await self._async_render_GET(request) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 7619f405fa..99ebd96f84 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py
@@ -62,7 +62,6 @@ class Constraint(metaclass=abc.ABCMeta): @abc.abstractmethod def make_check_clause(self, table: str) -> str: """Returns an SQL expression that checks the row passes the constraint.""" - pass @abc.abstractmethod def make_constraint_clause_postgres(self) -> str: @@ -70,7 +69,6 @@ class Constraint(metaclass=abc.ABCMeta): Only used on Postgres DBs """ - pass @attr.s(auto_attribs=True) diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index abd1d149db..6864f93090 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py
@@ -154,12 +154,13 @@ class _UpdateCurrentStateTask: _EventPersistQueueTask = Union[_PersistEventsTask, _UpdateCurrentStateTask] +_PersistResult = TypeVar("_PersistResult") @attr.s(auto_attribs=True, slots=True) -class _EventPersistQueueItem: +class _EventPersistQueueItem(Generic[_PersistResult]): task: _EventPersistQueueTask - deferred: ObservableDeferred + deferred: ObservableDeferred[_PersistResult] parent_opentracing_span_contexts: List = attr.ib(factory=list) """A list of opentracing spans waiting for this batch""" @@ -168,9 +169,6 @@ class _EventPersistQueueItem: """The opentracing span under which the persistence actually happened""" -_PersistResult = TypeVar("_PersistResult") - - class _EventPeristenceQueue(Generic[_PersistResult]): """Queues up tasks so that they can be processed with only one concurrent transaction per room. diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 55ac313f33..6c5fcdcec3 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py
@@ -422,10 +422,11 @@ class LoggingTransaction: return self._do_execute( # TODO: is it safe for values to be Iterable[Iterable[Any]] here? # https://www.psycopg.org/docs/extras.html?highlight=execute_batch#psycopg2.extras.execute_values says values should be Sequence[Sequence] - lambda the_sql: execute_values( - self.txn, the_sql, values, template=template, fetch=fetch + lambda the_sql, the_values: execute_values( + self.txn, the_sql, the_values, template=template, fetch=fetch ), sql, + values, ) def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None: diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 25f70fee84..0be12f0e06 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py
@@ -349,7 +349,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): table="devices", column="user_id", iterable=user_ids_to_query, - keyvalues={"user_id": user_id, "hidden": False}, + keyvalues={"hidden": False}, retcols=("device_id",), ) @@ -445,13 +445,18 @@ class DeviceInboxWorkerStore(SQLBaseStore): @trace async def delete_messages_for_device( - self, user_id: str, device_id: Optional[str], up_to_stream_id: int + self, + user_id: str, + device_id: Optional[str], + up_to_stream_id: int, + limit: int, ) -> int: """ Args: user_id: The recipient user_id. device_id: The recipient device_id. up_to_stream_id: Where to delete messages up to. + limit: maximum number of messages to delete Returns: The number of messages deleted. @@ -472,12 +477,16 @@ class DeviceInboxWorkerStore(SQLBaseStore): log_kv({"message": "No changes in cache since last check"}) return 0 + ROW_ID_NAME = self.database_engine.row_id_name + def delete_messages_for_device_txn(txn: LoggingTransaction) -> int: - sql = ( - "DELETE FROM device_inbox" - " WHERE user_id = ? AND device_id = ?" - " AND stream_id <= ?" - ) + sql = f""" + DELETE FROM device_inbox WHERE {ROW_ID_NAME} IN ( + SELECT {ROW_ID_NAME} FROM device_inbox + WHERE user_id = ? AND device_id = ? AND stream_id <= ? + LIMIT {limit} + ) + """ txn.execute(sql, (user_id, device_id, up_to_stream_id)) return txn.rowcount @@ -487,6 +496,11 @@ class DeviceInboxWorkerStore(SQLBaseStore): log_kv({"message": f"deleted {count} messages for device", "count": count}) + # In this case we don't know if we hit the limit or the delete is complete + # so let's not update the cache. + if count == limit: + return count + # Update the cache, ensuring that we only ever increase the value updated_last_deleted_stream_id = self._last_device_delete_cache.get( (user_id, device_id), 0 diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index e4162f846b..70faf4b1ec 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -759,18 +759,10 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): mapping of user_id -> device_id -> device_info. """ unique_user_ids = user_ids | {user_id for user_id, _ in user_and_device_ids} - user_map = await self.get_device_list_last_stream_id_for_remotes( - list(unique_user_ids) - ) - # We go and check if any of the users need to have their device lists - # resynced. If they do then we remove them from the cached list. - users_needing_resync = await self.get_user_ids_requiring_device_list_resync( + user_ids_in_cache = await self.get_users_whose_devices_are_cached( unique_user_ids ) - user_ids_in_cache = { - user_id for user_id, stream_id in user_map.items() if stream_id - } - users_needing_resync user_ids_not_in_cache = unique_user_ids - user_ids_in_cache # First fetch all the users which all devices are to be returned. @@ -792,6 +784,22 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return user_ids_not_in_cache, results + async def get_users_whose_devices_are_cached( + self, user_ids: StrCollection + ) -> Set[str]: + """Checks which of the given users we have cached the devices for.""" + user_map = await self.get_device_list_last_stream_id_for_remotes(user_ids) + + # We go and check if any of the users need to have their device lists + # resynced. If they do then we remove them from the cached list. + users_needing_resync = await self.get_user_ids_requiring_device_list_resync( + user_ids + ) + user_ids_in_cache = { + user_id for user_id, stream_id in user_map.items() if stream_id + } - users_needing_resync + return user_ids_in_cache + @cached(num_args=2, tree=True) async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict: content = await self.db_pool.simple_select_one_onecol( @@ -1766,14 +1774,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.db_pool.simple_delete_many_txn( txn, - table="device_inbox", - column="device_id", - values=device_ids, - keyvalues={"user_id": user_id}, - ) - - self.db_pool.simple_delete_many_txn( - txn, table="device_auth_providers", column="device_id", values=device_ids, diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 07bda7d6be..b958a39aeb 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py
@@ -1740,42 +1740,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # We sleep to ensure that we don't overwhelm the DB. await self._clock.sleep(1.0) - -class EventPushActionsStore(EventPushActionsWorkerStore): - EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): - super().__init__(database, db_conn, hs) - - self.db_pool.updates.register_background_index_update( - self.EPA_HIGHLIGHT_INDEX, - index_name="event_push_actions_u_highlight", - table="event_push_actions", - columns=["user_id", "stream_ordering"], - ) - - self.db_pool.updates.register_background_index_update( - "event_push_actions_highlights_index", - index_name="event_push_actions_highlights_index", - table="event_push_actions", - columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], - where_clause="highlight=1", - ) - - # Add index to make deleting old push actions faster. - self.db_pool.updates.register_background_index_update( - "event_push_actions_stream_highlight_index", - index_name="event_push_actions_stream_highlight_index", - table="event_push_actions", - columns=["highlight", "stream_ordering"], - where_clause="highlight=0", - ) - async def get_push_actions_for_user( self, user_id: str, @@ -1834,6 +1798,42 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ] +class EventPushActionsStore(EventPushActionsWorkerStore): + EPA_HIGHLIGHT_INDEX = "epa_highlight_index" + + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_index_update( + self.EPA_HIGHLIGHT_INDEX, + index_name="event_push_actions_u_highlight", + table="event_push_actions", + columns=["user_id", "stream_ordering"], + ) + + self.db_pool.updates.register_background_index_update( + "event_push_actions_highlights_index", + index_name="event_push_actions_highlights_index", + table="event_push_actions", + columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], + where_clause="highlight=1", + ) + + # Add index to make deleting old push actions faster. + self.db_pool.updates.register_background_index_update( + "event_push_actions_stream_highlight_index", + index_name="event_push_actions_stream_highlight_index", + table="event_push_actions", + columns=["highlight", "stream_ordering"], + where_clause="highlight=0", + ) + + def _action_has_highlight(actions: Collection[Union[Mapping, str]]) -> bool: for action in actions: if not isinstance(action, dict): diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index a3b4744855..41563371dc 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py
@@ -16,14 +16,17 @@ import itertools import json import logging -from typing import Dict, Iterable, Mapping, Optional, Tuple +from typing import Dict, Iterable, Optional, Tuple +from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 +from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote from synapse.storage.types import Cursor +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter @@ -36,162 +39,84 @@ db_binary_type = memoryview class KeyStore(CacheInvalidationWorkerStore): """Persistence for signature verification keys""" - @cached() - def _get_server_signature_key( - self, server_name_and_key_id: Tuple[str, str] - ) -> FetchKeyResult: - raise NotImplementedError() - - @cachedList( - cached_method_name="_get_server_signature_key", - list_name="server_name_and_key_ids", - ) - async def get_server_signature_keys( - self, server_name_and_key_ids: Iterable[Tuple[str, str]] - ) -> Dict[Tuple[str, str], FetchKeyResult]: - """ - Args: - server_name_and_key_ids: - iterable of (server_name, key-id) tuples to fetch keys for - - Returns: - A map from (server_name, key_id) -> FetchKeyResult, or None if the - key is unknown - """ - keys = {} - - def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None: - """Processes a batch of keys to fetch, and adds the result to `keys`.""" - - # batch_iter always returns tuples so it's safe to do len(batch) - sql = """ - SELECT server_name, key_id, verify_key, ts_valid_until_ms - FROM server_signature_keys WHERE 1=0 - """ + " OR (server_name=? AND key_id=?)" * len( - batch - ) - - txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) - - for row in txn: - server_name, key_id, key_bytes, ts_valid_until_ms = row - - if ts_valid_until_ms is None: - # Old keys may be stored with a ts_valid_until_ms of null, - # in which case we treat this as if it was set to `0`, i.e. - # it won't match key requests that define a minimum - # `ts_valid_until_ms`. - ts_valid_until_ms = 0 - - keys[(server_name, key_id)] = FetchKeyResult( - verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), - valid_until_ts=ts_valid_until_ms, - ) - - def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: - for batch in batch_iter(server_name_and_key_ids, 50): - _get_keys(txn, batch) - return keys - - return await self.db_pool.runInteraction("get_server_signature_keys", _txn) - - async def store_server_signature_keys( + async def store_server_keys_response( self, + server_name: str, from_server: str, ts_added_ms: int, - verify_keys: Mapping[Tuple[str, str], FetchKeyResult], + verify_keys: Dict[str, FetchKeyResult], + response_json: JsonDict, ) -> None: - """Stores NACL verification keys for remote servers. + """Stores the keys for the given server that we got from `from_server`. + Args: - from_server: Where the verification keys were looked up - ts_added_ms: The time to record that the key was added - verify_keys: - keys to be stored. Each entry is a triplet of - (server_name, key_id, key). + server_name: The owner of the keys + from_server: Which server we got the keys from + ts_added_ms: When we're adding the keys + verify_keys: The decoded keys + response_json: The full *signed* response JSON that contains the keys. """ - key_values = [] - value_values = [] - invalidations = [] - for (server_name, key_id), fetch_result in verify_keys.items(): - key_values.append((server_name, key_id)) - value_values.append( - ( - from_server, - ts_added_ms, - fetch_result.valid_until_ts, - db_binary_type(fetch_result.verify_key.encode()), - ) - ) - # invalidate takes a tuple corresponding to the params of - # _get_server_signature_key. _get_server_signature_key only takes one - # param, which is itself the 2-tuple (server_name, key_id). - invalidations.append((server_name, key_id)) - await self.db_pool.simple_upsert_many( - table="server_signature_keys", - key_names=("server_name", "key_id"), - key_values=key_values, - value_names=( - "from_server", - "ts_added_ms", - "ts_valid_until_ms", - "verify_key", - ), - value_values=value_values, - desc="store_server_signature_keys", - ) + key_json_bytes = encode_canonical_json(response_json) + + def store_server_keys_response_txn(txn: LoggingTransaction) -> None: + self.db_pool.simple_upsert_many_txn( + txn, + table="server_signature_keys", + key_names=("server_name", "key_id"), + key_values=[(server_name, key_id) for key_id in verify_keys], + value_names=( + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "verify_key", + ), + value_values=[ + ( + from_server, + ts_added_ms, + fetch_result.valid_until_ts, + db_binary_type(fetch_result.verify_key.encode()), + ) + for fetch_result in verify_keys.values() + ], + ) - invalidate = self._get_server_signature_key.invalidate - for i in invalidations: - invalidate((i,)) + self.db_pool.simple_upsert_many_txn( + txn, + table="server_keys_json", + key_names=("server_name", "key_id", "from_server"), + key_values=[ + (server_name, key_id, from_server) for key_id in verify_keys + ], + value_names=( + "ts_added_ms", + "ts_valid_until_ms", + "key_json", + ), + value_values=[ + ( + ts_added_ms, + fetch_result.valid_until_ts, + db_binary_type(key_json_bytes), + ) + for fetch_result in verify_keys.values() + ], + ) - async def store_server_keys_json( - self, - server_name: str, - key_id: str, - from_server: str, - ts_now_ms: int, - ts_expires_ms: int, - key_json_bytes: bytes, - ) -> None: - """Stores the JSON bytes for a set of keys from a server - The JSON should be signed by the originating server, the intermediate - server, and by this server. Updates the value for the - (server_name, key_id, from_server) triplet if one already existed. - Args: - server_name: The name of the server. - key_id: The identifier of the key this JSON is for. - from_server: The server this JSON was fetched from. - ts_now_ms: The time now in milliseconds. - ts_valid_until_ms: The time when this json stops being valid. - key_json_bytes: The encoded JSON. - """ - await self.db_pool.simple_upsert( - table="server_keys_json", - keyvalues={ - "server_name": server_name, - "key_id": key_id, - "from_server": from_server, - }, - values={ - "server_name": server_name, - "key_id": key_id, - "from_server": from_server, - "ts_added_ms": ts_now_ms, - "ts_valid_until_ms": ts_expires_ms, - "key_json": db_binary_type(key_json_bytes), - }, - desc="store_server_keys_json", - ) + # invalidate takes a tuple corresponding to the params of + # _get_server_keys_json. _get_server_keys_json only takes one + # param, which is itself the 2-tuple (server_name, key_id). + for key_id in verify_keys: + self._invalidate_cache_and_stream( + txn, self._get_server_keys_json, ((server_name, key_id),) + ) + self._invalidate_cache_and_stream( + txn, self.get_server_key_json_for_remote, (server_name, key_id) + ) - # invalidate takes a tuple corresponding to the params of - # _get_server_keys_json. _get_server_keys_json only takes one - # param, which is itself the 2-tuple (server_name, key_id). - await self.invalidate_cache_and_stream( - "_get_server_keys_json", ((server_name, key_id),) - ) - await self.invalidate_cache_and_stream( - "get_server_key_json_for_remote", (server_name, key_id) + await self.db_pool.runInteraction( + "store_server_keys_response", store_server_keys_response_txn ) @cached() @@ -221,12 +146,17 @@ class KeyStore(CacheInvalidationWorkerStore): """Processes a batch of keys to fetch, and adds the result to `keys`.""" # batch_iter always returns tuples so it's safe to do len(batch) - sql = """ - SELECT server_name, key_id, key_json, ts_valid_until_ms - FROM server_keys_json WHERE 1=0 - """ + " OR (server_name=? AND key_id=?)" * len( - batch - ) + where_clause = " OR (server_name=? AND key_id=?)" * len(batch) + + # `server_keys_json` can have multiple entries per server (one per + # remote server we fetched from, if using perspectives). Order by + # `ts_added_ms` so the most recently fetched one always wins. + sql = f""" + SELECT server_name, key_id, key_json, ts_valid_until_ms + FROM server_keys_json WHERE 1=0 + {where_clause} + ORDER BY ts_added_ms + """ txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index b52f48cf04..dea0e0458c 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py
@@ -450,10 +450,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "e2e_room_keys", "event_push_summary", "pusher_throttle", - "insertion_events", - "insertion_event_extremities", - "insertion_event_edges", - "batch_events", "room_account_data", "room_tags", # "rooms" happens last, to keep the foreign keys in the other tables diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 5ee5c7ad9f..e4d10ff250 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py
@@ -939,11 +939,7 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): receipts.""" def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None: - if isinstance(self.database_engine, PostgresEngine): - ROW_ID_NAME = "ctid" - else: - ROW_ID_NAME = "rowid" - + ROW_ID_NAME = self.database_engine.row_id_name # Identify any duplicate receipts arising from # https://github.com/matrix-org/synapse/issues/14406. # The following query takes less than a minute on matrix.org. diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index 0b5b3bf03e..b1a2418cbd 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py
@@ -100,6 +100,12 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM """Gets a string giving the server version. For example: '3.22.0'""" ... + @property + @abc.abstractmethod + def row_id_name(self) -> str: + """Gets the literal name representing a row id for this engine.""" + ... + @abc.abstractmethod def in_transaction(self, conn: ConnectionType) -> bool: """Whether the connection is currently in a transaction.""" diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 05a72dc554..6309363217 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py
@@ -211,6 +211,10 @@ class PostgresEngine( else: return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100) + @property + def row_id_name(self) -> str: + return "ctid" + def in_transaction(self, conn: psycopg2.extensions.connection) -> bool: return conn.status != psycopg2.extensions.STATUS_READY diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index ca8c59297c..802069e1e1 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py
@@ -123,6 +123,10 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]): """Gets a string giving the server version. For example: '3.22.0'.""" return "%i.%i.%i" % sqlite3.sqlite_version_info + @property + def row_id_name(self) -> str: + return "rowid" + def in_transaction(self, conn: sqlite3.Connection) -> bool: return conn.in_transaction diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 422f11f59e..5b50bd66bc 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 81 # remember to update the list below when updating +SCHEMA_VERSION = 82 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -117,6 +117,10 @@ Changes in SCHEMA_VERSION = 80 Changes in SCHEMA_VERSION = 81 - The event_txn_id is no longer written to for new events. + +Changes in SCHEMA_VERSION = 82 + - The insertion_events, insertion_event_extremities, insertion_event_edges, and + batch_events tables are no longer purged in preparation for their removal. """ diff --git a/synapse/storage/schema/main/delta/48/group_unique_indexes.py b/synapse/storage/schema/main/delta/48/group_unique_indexes.py
index ad2da4c8af..622686d28f 100644 --- a/synapse/storage/schema/main/delta/48/group_unique_indexes.py +++ b/synapse/storage/schema/main/delta/48/group_unique_indexes.py
@@ -14,7 +14,7 @@ from synapse.storage.database import LoggingTransaction -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.prepare_database import get_statements FIX_INDEXES = """ @@ -37,7 +37,7 @@ CREATE INDEX group_rooms_r_idx ON group_rooms(room_id); def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: - rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid" + rowid = database_engine.row_id_name # remove duplicates from group_users & group_invites tables cur.execute( diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 943ad54456..0cbeb0c365 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py
@@ -19,6 +19,7 @@ import collections import inspect import itertools import logging +import typing from contextlib import asynccontextmanager from typing import ( Any, @@ -29,6 +30,7 @@ from typing import ( Collection, Coroutine, Dict, + Generator, Generic, Hashable, Iterable, @@ -398,7 +400,7 @@ class _LinearizerEntry: # The number of things executing. count: int # Deferreds for the things blocked from executing. - deferreds: collections.OrderedDict + deferreds: typing.OrderedDict["defer.Deferred[None]", Literal[1]] class Linearizer: @@ -717,30 +719,25 @@ def timeout_deferred( return new_d -# This class can't be generic because it uses slots with attrs. -# See: https://github.com/python-attrs/attrs/issues/313 @attr.s(slots=True, frozen=True, auto_attribs=True) -class DoneAwaitable: # should be: Generic[R] +class DoneAwaitable(Awaitable[R]): """Simple awaitable that returns the provided value.""" - value: Any # should be: R + value: R - def __await__(self) -> Any: - return self - - def __iter__(self) -> "DoneAwaitable": - return self - - def __next__(self) -> None: - raise StopIteration(self.value) + def __await__(self) -> Generator[Any, None, R]: + yield None + return self.value def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]: """Convert a value to an awaitable if not already an awaitable.""" if inspect.isawaitable(value): - assert isinstance(value, Awaitable) return value + # For some reason mypy doesn't deduce that value is not Awaitable here, even though + # inspect.isawaitable returns a TypeGuard. + assert not isinstance(value, Awaitable) return DoneAwaitable(value) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 5eaf70c7ab..2fbc7b1e6c 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py
@@ -14,7 +14,7 @@ import enum import logging import threading -from typing import Any, Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union +from typing import Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union import attr from typing_extensions import Literal @@ -33,10 +33,8 @@ DKT = TypeVar("DKT") DV = TypeVar("DV") -# This class can't be generic because it uses slots with attrs. -# See: https://github.com/python-attrs/attrs/issues/313 @attr.s(slots=True, frozen=True, auto_attribs=True) -class DictionaryEntry: # should be: Generic[DKT, DV]. +class DictionaryEntry(Generic[DKT, DV]): """Returned when getting an entry from the cache If `full` is true then `known_absent` will be the empty set. @@ -50,8 +48,8 @@ class DictionaryEntry: # should be: Generic[DKT, DV]. """ full: bool - known_absent: Set[Any] # should be: Set[DKT] - value: Dict[Any, Any] # should be: Dict[DKT, DV] + known_absent: Set[DKT] + value: Dict[DKT, DV] def __len__(self) -> int: return len(self.value) diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 01ad02af67..8e4c34039d 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py
@@ -14,7 +14,7 @@ import logging from collections import OrderedDict -from typing import Any, Generic, Optional, TypeVar, Union, overload +from typing import Any, Generic, Iterable, Optional, TypeVar, Union, overload import attr from typing_extensions import Literal @@ -73,7 +73,7 @@ class ExpiringCache(Generic[KT, VT]): self._expiry_ms = expiry_ms self._reset_expiry_on_get = reset_expiry_on_get - self._cache: OrderedDict[KT, _CacheEntry] = OrderedDict() + self._cache: OrderedDict[KT, _CacheEntry[VT]] = OrderedDict() self.iterable = iterable @@ -100,7 +100,10 @@ class ExpiringCache(Generic[KT, VT]): while self._max_size and len(self) > self._max_size: _key, value = self._cache.popitem(last=False) if self.iterable: - self.metrics.inc_evictions(EvictionReason.size, len(value.value)) + # type-ignore, here and below: if self.iterable is true, then the value + # type VT should be Sized (i.e. have a __len__ method). We don't enforce + # this via the type system at present. + self.metrics.inc_evictions(EvictionReason.size, len(value.value)) # type: ignore[arg-type] else: self.metrics.inc_evictions(EvictionReason.size) @@ -134,7 +137,7 @@ class ExpiringCache(Generic[KT, VT]): return default if self.iterable: - self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value)) + self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value)) # type: ignore[arg-type] else: self.metrics.inc_evictions(EvictionReason.invalidation) @@ -182,7 +185,7 @@ class ExpiringCache(Generic[KT, VT]): for k in keys_to_delete: value = self._cache.pop(k) if self.iterable: - self.metrics.inc_evictions(EvictionReason.time, len(value.value)) + self.metrics.inc_evictions(EvictionReason.time, len(value.value)) # type: ignore[arg-type] else: self.metrics.inc_evictions(EvictionReason.time) @@ -195,7 +198,8 @@ class ExpiringCache(Generic[KT, VT]): def __len__(self) -> int: if self.iterable: - return sum(len(entry.value) for entry in self._cache.values()) + g: Iterable[int] = (len(entry.value) for entry in self._cache.values()) # type: ignore[arg-type] + return sum(g) else: return len(self._cache) @@ -218,6 +222,6 @@ class ExpiringCache(Generic[KT, VT]): @attr.s(slots=True, auto_attribs=True) -class _CacheEntry: +class _CacheEntry(Generic[VT]): time: int - value: Any + value: VT diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index f6b3ee31e4..48a6e4a906 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py
@@ -35,10 +35,10 @@ class TTLCache(Generic[KT, VT]): def __init__(self, cache_name: str, timer: Callable[[], float] = time.time): # map from key to _CacheEntry - self._data: Dict[KT, _CacheEntry] = {} + self._data: Dict[KT, _CacheEntry[KT, VT]] = {} # the _CacheEntries, sorted by expiry time - self._expiry_list: SortedList[_CacheEntry] = SortedList() + self._expiry_list: SortedList[_CacheEntry[KT, VT]] = SortedList() self._timer = timer @@ -160,11 +160,11 @@ class TTLCache(Generic[KT, VT]): @attr.s(frozen=True, slots=True, auto_attribs=True) -class _CacheEntry: # Should be Generic[KT, VT]. See python-attrs/attrs#313 +class _CacheEntry(Generic[KT, VT]): """TTLCache entry""" # expiry_time is the first attribute, so that entries are sorted by expiry. expiry_time: float ttl: float - key: Any # should be KT - value: Any # should be VT + key: KT + value: VT diff --git a/synapse/util/gai_resolver.py b/synapse/util/gai_resolver.py
index 214eb17fbc..fecf829ade 100644 --- a/synapse/util/gai_resolver.py +++ b/synapse/util/gai_resolver.py
@@ -136,7 +136,7 @@ class GAIResolver: # The types on IHostnameResolver is incorrect in Twisted, see # https://twistedmatrix.com/trac/ticket/10276 - def resolveHostName( # type: ignore[override] + def resolveHostName( self, resolutionReceiver: IResolutionReceiver, hostName: str, diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py
index 9e89aeb748..b7de201bde 100644 --- a/synapse/util/task_scheduler.py +++ b/synapse/util/task_scheduler.py
@@ -19,6 +19,7 @@ from prometheus_client import Gauge from twisted.python.failure import Failure +from synapse.logging.context import nested_logging_context from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import JsonMapping, ScheduledTask, TaskStatus from synapse.util.stringutils import random_string @@ -77,6 +78,7 @@ class TaskScheduler: LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000 # 24hrs def __init__(self, hs: "HomeServer"): + self._hs = hs self._store = hs.get_datastores().main self._clock = hs.get_clock() self._running_tasks: Set[str] = set() @@ -97,8 +99,6 @@ class TaskScheduler: "handle_scheduled_tasks", self._handle_scheduled_tasks, ) - else: - self.replication_client = hs.get_replication_command_handler() def register_action( self, @@ -133,7 +133,7 @@ class TaskScheduler: params: Optional[JsonMapping] = None, ) -> str: """Schedule a new potentially resumable task. A function matching the specified - `action` should have been previously registered with `register_action`. + `action` should have be registered with `register_action` before the task is run. Args: action: the name of a previously registered action @@ -149,11 +149,6 @@ class TaskScheduler: Returns: The id of the scheduled task """ - if action not in self._actions: - raise Exception( - f"No function associated with action {action} of the scheduled task" - ) - status = TaskStatus.SCHEDULED if timestamp is None or timestamp < self._clock.time_msec(): timestamp = self._clock.time_msec() @@ -175,7 +170,7 @@ class TaskScheduler: if self._run_background_tasks: await self._launch_task(task) else: - self.replication_client.send_new_active_task(task.id) + self._hs.get_replication_command_handler().send_new_active_task(task.id) return task.id @@ -315,30 +310,34 @@ class TaskScheduler: """ assert self._run_background_tasks - assert task.action in self._actions + if task.action not in self._actions: + raise Exception( + f"No function associated with action {task.action} of the scheduled task {task.id}" + ) function = self._actions[task.action] async def wrapper() -> None: - try: - (status, result, error) = await function(task) - except Exception: - f = Failure() - logger.error( - f"scheduled task {task.id} failed", - exc_info=(f.type, f.value, f.getTracebackObject()), + with nested_logging_context(task.id): + try: + (status, result, error) = await function(task) + except Exception: + f = Failure() + logger.error( + f"scheduled task {task.id} failed", + exc_info=(f.type, f.value, f.getTracebackObject()), + ) + status = TaskStatus.FAILED + result = None + error = f.getErrorMessage() + + await self._store.update_scheduled_task( + task.id, + self._clock.time_msec(), + status=status, + result=result, + error=error, ) - status = TaskStatus.FAILED - result = None - error = f.getErrorMessage() - - await self._store.update_scheduled_task( - task.id, - self._clock.time_msec(), - status=status, - result=result, - error=error, - ) - self._running_tasks.remove(task.id) + self._running_tasks.remove(task.id) if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: return @@ -356,5 +355,4 @@ class TaskScheduler: self._running_tasks.add(task.id) await self.update_task(task.id, status=TaskStatus.ACTIVE) - description = f"{task.id}-{task.action}" - run_as_background_process(description, wrapper) + run_as_background_process(task.action, wrapper)