diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py
index 3a516093f5..14cba50c90 100644
--- a/synapse/api/auth/msc3861_delegated.py
+++ b/synapse/api/auth/msc3861_delegated.py
@@ -20,6 +20,7 @@ from authlib.oauth2.auth import encode_client_secret_basic, encode_client_secret
from authlib.oauth2.rfc7523 import ClientSecretJWT, PrivateKeyJWT, private_key_jwt_sign
from authlib.oauth2.rfc7662 import IntrospectionToken
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
+from prometheus_client import Histogram
from twisted.web.client import readBody
from twisted.web.http_headers import Headers
@@ -46,6 +47,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+introspection_response_timer = Histogram(
+ "synapse_api_auth_delegated_introspection_response",
+ "Time taken to get a response for an introspection request",
+ ["code"],
+)
+
+
# Scope as defined by MSC2967
# https://github.com/matrix-org/matrix-spec-proposals/pull/2967
SCOPE_MATRIX_API = "urn:matrix:org.matrix.msc2967.client:api:*"
@@ -190,14 +198,26 @@ class MSC3861DelegatedAuth(BaseAuth):
# Do the actual request
# We're not using the SimpleHttpClient util methods as we don't want to
# check the HTTP status code, and we do the body encoding ourselves.
- response = await self._http_client.request(
- method="POST",
- uri=uri,
- data=body.encode("utf-8"),
- headers=headers,
- )
- resp_body = await make_deferred_yieldable(readBody(response))
+ start_time = self._clock.time()
+ try:
+ response = await self._http_client.request(
+ method="POST",
+ uri=uri,
+ data=body.encode("utf-8"),
+ headers=headers,
+ )
+
+ resp_body = await make_deferred_yieldable(readBody(response))
+ except Exception:
+ end_time = self._clock.time()
+ introspection_response_timer.labels("ERR").observe(end_time - start_time)
+ raise
+
+ end_time = self._clock.time()
+ introspection_response_timer.labels(response.code).observe(
+ end_time - start_time
+ )
if response.code < 200 or response.code >= 300:
raise HttpResponseException(
@@ -226,7 +246,7 @@ class MSC3861DelegatedAuth(BaseAuth):
return introspection_token
async def is_server_admin(self, requester: Requester) -> bool:
- return "urn:synapse:admin:*" in requester.scope
+ return SCOPE_SYNAPSE_ADMIN in requester.scope
async def get_user_by_req(
self,
@@ -243,6 +263,25 @@ class MSC3861DelegatedAuth(BaseAuth):
# so that we don't provision the user if they don't have enough permission:
requester = await self.get_user_by_access_token(access_token, allow_expired)
+ # Allow impersonation by an admin user using the `_oidc_admin_impersonate_user_id` query parameter
+ if request.args is not None:
+ user_id_params = request.args.get(b"_oidc_admin_impersonate_user_id")
+ if user_id_params:
+ if await self.is_server_admin(requester):
+ user_id_str = user_id_params[0].decode("ascii")
+ impersonated_user_id = UserID.from_string(user_id_str)
+ logging.info(f"Admin impersonation of user {user_id_str}")
+ requester = create_requester(
+ user_id=impersonated_user_id,
+ scope=[SCOPE_MATRIX_API],
+ authenticated_entity=requester.user.to_string(),
+ )
+ else:
+ raise AuthError(
+ 401,
+ "Impersonation not possible by a non admin user",
+ )
+
# Deny the request if the user account is locked.
if not allow_locked and await self.store.get_user_locked_status(
requester.user.to_string()
@@ -270,14 +309,14 @@ class MSC3861DelegatedAuth(BaseAuth):
# XXX: This is a temporary solution so that the admin API can be called by
# the OIDC provider. This will be removed once we have OIDC client
# credentials grant support in matrix-authentication-service.
- logging.info("Admin toked used")
+ logging.info("Admin token used")
# XXX: that user doesn't exist and won't be provisioned.
# This is mostly fine for admin calls, but we should also think about doing
# requesters without a user_id.
admin_user = UserID("__oidc_admin", self._hostname)
return create_requester(
user_id=admin_user,
- scope=["urn:synapse:admin:*"],
+ scope=[SCOPE_SYNAPSE_ADMIN],
)
try:
@@ -399,3 +438,16 @@ class MSC3861DelegatedAuth(BaseAuth):
scope=scope,
is_guest=(has_guest_scope and not has_user_scope),
)
+
+ def invalidate_cached_tokens(self, keys: List[str]) -> None:
+ """
+ Invalidate the entries in the introspection token cache corresponding to the given keys
+ """
+ for key in keys:
+ self._token_cache.invalidate(key)
+
+ def invalidate_token_cache(self) -> None:
+ """
+ Invalidate the entire token cache.
+ """
+ self._token_cache.invalidate_all()
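The new `introspection_response_timer` records one observation per introspection request, labelled with the HTTP status code, or the literal `"ERR"` when the request raises. A minimal standalone sketch of the same pattern (the metric and helper names here are hypothetical, not part of this change):

```python
import time

from prometheus_client import Histogram

# Hypothetical metric mirroring introspection_response_timer above.
example_introspection_timer = Histogram(
    "example_introspection_response",
    "Time taken to get a response for an introspection request",
    ["code"],
)


def timed_request(do_request):
    """Run `do_request` and record its duration under the status code label."""
    start = time.monotonic()
    try:
        code = do_request()
    except Exception:
        # Failures are bucketed under a synthetic "ERR" label, as in the diff.
        example_introspection_timer.labels("ERR").observe(time.monotonic() - start)
        raise
    example_introspection_timer.labels(code).observe(time.monotonic() - start)
    return code
```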
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index dc79efcc14..d25e3548e0 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -91,6 +91,7 @@ from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.databases.main.stream import StreamWorkerStore
from synapse.storage.databases.main.tags import TagsWorkerStore
+from synapse.storage.databases.main.task_scheduler import TaskSchedulerWorkerStore
from synapse.storage.databases.main.transactions import TransactionWorkerStore
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
from synapse.storage.databases.main.user_directory import UserDirectoryStore
@@ -144,6 +145,7 @@ class GenericWorkerStore(
TransactionWorkerStore,
LockStore,
SessionStore,
+ TaskSchedulerWorkerStore,
):
# Properties that multiple storage classes define. Tell mypy what the
# expected type is.
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index ac9449b18f..277ea4675b 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -173,6 +173,13 @@ class MSC3861:
("enable_registration",),
)
+ # We only need to test the user consent version, as it must be set if the user_consent section was present in the config
+ if root.consent.user_consent_version is not None:
+ raise ConfigError(
+ "User consent cannot be enabled when OAuth delegation is enabled",
+ ("user_consent",),
+ )
+
if (
root.oidc.oidc_enabled
or root.saml2.saml2_enabled
@@ -216,6 +223,12 @@ class MSC3861:
("session_lifetime",),
)
+ if root.registration.enable_3pid_changes:
+ raise ConfigError(
+ "enable_3pid_changes cannot be enabled when OAuth delegation is enabled",
+ ("enable_3pid_changes",),
+ )
+
@attr.s(auto_attribs=True, frozen=True, slots=True)
class MSC3866Config:
diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py
index 77c1d1dc8e..574d6afb95 100644
--- a/synapse/config/oidc.py
+++ b/synapse/config/oidc.py
@@ -280,6 +280,20 @@ def _parse_oidc_config_dict(
for x in oidc_config.get("attribute_requirements", [])
]
+ # Read from either `client_secret_path` or `client_secret`. If both exist, error.
+ client_secret = oidc_config.get("client_secret")
+ client_secret_path = oidc_config.get("client_secret_path")
+ if client_secret_path is not None:
+ if client_secret is None:
+ client_secret = read_file(
+ client_secret_path, config_path + ("client_secret_path",)
+ ).rstrip("\n")
+ else:
+ raise ConfigError(
+ "Cannot specify both client_secret and client_secret_path",
+ config_path + ("client_secret",),
+ )
+
return OidcProviderConfig(
idp_id=idp_id,
idp_name=oidc_config.get("idp_name", "OIDC"),
@@ -288,7 +302,7 @@ def _parse_oidc_config_dict(
discover=oidc_config.get("discover", True),
issuer=oidc_config["issuer"],
client_id=oidc_config["client_id"],
- client_secret=oidc_config.get("client_secret"),
+ client_secret=client_secret,
client_secret_jwt_key=client_secret_jwt_key,
client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
pkce_method=oidc_config.get("pkce_method", "auto"),
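The new `client_secret_path` option is mutually exclusive with `client_secret` and reads the secret from a file, stripping a trailing newline. A standalone sketch of the same precedence rule, using a plain `open()` instead of Synapse's `read_file` helper (hypothetical function, not part of this change):

```python
from typing import Optional


def resolve_client_secret(
    client_secret: Optional[str], client_secret_path: Optional[str]
) -> Optional[str]:
    """Return the configured secret from either source, erroring if both are set."""
    if client_secret_path is None:
        return client_secret
    if client_secret is not None:
        # Mirrors the ConfigError raised above: the two options are exclusive.
        raise ValueError("Cannot specify both client_secret and client_secret_path")
    with open(client_secret_path, encoding="utf-8") as f:
        # A secret written with a trailing newline should compare equal to the
        # same secret configured inline.
        return f.read().rstrip("\n")
```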
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index df1d83dfaa..b8ad6fbc06 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -133,7 +133,16 @@ class RegistrationConfig(Config):
self.enable_set_displayname = config.get("enable_set_displayname", True)
self.enable_set_avatar_url = config.get("enable_set_avatar_url", True)
- self.enable_3pid_changes = config.get("enable_3pid_changes", True)
+
+ # The default value of enable_3pid_changes is True, unless msc3861 is enabled.
+ msc3861_enabled = (
+ (config.get("experimental_features") or {})
+ .get("msc3861", {})
+ .get("enabled", False)
+ )
+ self.enable_3pid_changes = config.get(
+ "enable_3pid_changes", not msc3861_enabled
+ )
self.disable_msisdn_registration = config.get(
"disable_msisdn_registration", False
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 2b93b8c621..29cd45550a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -60,6 +60,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.events.validator import EventValidator
from synapse.federation.federation_client import InvalidResponseError
+from synapse.handlers.pagination import PURGE_PAGINATION_LOCK_NAME
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import nested_logging_context
from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace
@@ -152,6 +153,7 @@ class FederationHandler:
self._device_handler = hs.get_device_handler()
self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
self._notifier = hs.get_notifier()
+ self._worker_locks = hs.get_worker_locks_handler()
self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
hs
@@ -200,7 +202,7 @@ class FederationHandler:
@trace
@tag_args
async def maybe_backfill(
- self, room_id: str, current_depth: int, limit: int
+ self, room_id: str, current_depth: int, limit: int, record_time: bool = True
) -> bool:
"""Checks the database to see if we should backfill before paginating,
and if so do.
@@ -213,21 +215,25 @@ class FederationHandler:
limit: The number of events that the pagination request will
return. This is used as part of the heuristic to decide if we
should back paginate.
+ record_time: Whether to record the time it takes to backfill.
Returns:
True if we actually tried to backfill something, otherwise False.
"""
# Starting the processing time here so we can include the room backfill
# linearizer lock queue in the timing
- processing_start_time = self.clock.time_msec()
+ processing_start_time = self.clock.time_msec() if record_time else 0
async with self._room_backfill.queue(room_id):
- return await self._maybe_backfill_inner(
- room_id,
- current_depth,
- limit,
- processing_start_time=processing_start_time,
- )
+ async with self._worker_locks.acquire_read_write_lock(
+ PURGE_PAGINATION_LOCK_NAME, room_id, write=False
+ ):
+ return await self._maybe_backfill_inner(
+ room_id,
+ current_depth,
+ limit,
+ processing_start_time=processing_start_time,
+ )
@trace
@tag_args
@@ -305,12 +311,21 @@ class FederationHandler:
# of history that extends all the way back to where we are currently paginating
# and it's within the 100 events that are returned from `/backfill`.
if not sorted_backfill_points and current_depth != MAX_DEPTH:
+ # Check that we actually have later backfill points; if not, just return.
+ have_later_backfill_points = await self.store.get_backfill_points_in_room(
+ room_id=room_id,
+ current_depth=MAX_DEPTH,
+ limit=1,
+ )
+ if not have_later_backfill_points:
+ return False
+
logger.debug(
"_maybe_backfill_inner: all backfill points are *after* current depth. Trying again with later backfill points."
)
run_as_background_process(
"_maybe_backfill_inner_anyway_with_max_depth",
- self._maybe_backfill_inner,
+ self.maybe_backfill,
room_id=room_id,
# We use `MAX_DEPTH` so that we find all backfill points next
# time (all events are below the `MAX_DEPTH`)
@@ -319,7 +334,7 @@ class FederationHandler:
# We don't want to start another timing observation from this
# nested recursive call. The top-most call can record the time
# overall otherwise the smaller one will throw off the results.
- processing_start_time=None,
+ record_time=False,
)
# We return `False` because we're backfilling in the background and there is
# no new events immediately for the caller to know about yet.
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 1be6ebc6d9..e5ac9096cc 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -487,155 +487,150 @@ class PaginationHandler:
room_token = from_token.room_key
- async with self._worker_locks.acquire_read_write_lock(
- PURGE_PAGINATION_LOCK_NAME, room_id, write=False
- ):
- (membership, member_event_id) = (None, None)
- if not use_admin_priviledge:
- (
- membership,
- member_event_id,
- ) = await self.auth.check_user_in_room_or_world_readable(
- room_id, requester, allow_departed_users=True
+ (membership, member_event_id) = (None, None)
+ if not use_admin_priviledge:
+ (
+ membership,
+ member_event_id,
+ ) = await self.auth.check_user_in_room_or_world_readable(
+ room_id, requester, allow_departed_users=True
+ )
+
+ if pagin_config.direction == Direction.BACKWARDS:
+ # if we're going backwards, we might need to backfill. This
+ # requires that we have a topo token.
+ if room_token.topological:
+ curr_topo = room_token.topological
+ else:
+ curr_topo = await self.store.get_current_topological_token(
+ room_id, room_token.stream
)
- if pagin_config.direction == Direction.BACKWARDS:
- # if we're going backwards, we might need to backfill. This
- # requires that we have a topo token.
- if room_token.topological:
- curr_topo = room_token.topological
- else:
- curr_topo = await self.store.get_current_topological_token(
- room_id, room_token.stream
- )
+ # If they have left the room then clamp the token to be before
+ # they left the room, to save the effort of loading from the
+ # database.
+ if (
+ pagin_config.direction == Direction.BACKWARDS
+ and not use_admin_priviledge
+ and membership == Membership.LEAVE
+ ):
+ # This is only None if the room is world_readable, in which case
+ # "Membership.JOIN" would have been returned and we should never hit
+ # this branch.
+ assert member_event_id
- # If they have left the room then clamp the token to be before
- # they left the room, to save the effort of loading from the
- # database.
- if (
- pagin_config.direction == Direction.BACKWARDS
- and not use_admin_priviledge
- and membership == Membership.LEAVE
- ):
- # This is only None if the room is world_readable, in which case
- # "Membership.JOIN" would have been returned and we should never hit
- # this branch.
- assert member_event_id
+ leave_token = await self.store.get_topological_token_for_event(
+ member_event_id
+ )
+ assert leave_token.topological is not None
- leave_token = await self.store.get_topological_token_for_event(
- member_event_id
+ if leave_token.topological < curr_topo:
+ from_token = from_token.copy_and_replace(
+ StreamKeyType.ROOM, leave_token
)
- assert leave_token.topological is not None
- if leave_token.topological < curr_topo:
- from_token = from_token.copy_and_replace(
- StreamKeyType.ROOM, leave_token
- )
+ to_room_key = None
+ if pagin_config.to_token:
+ to_room_key = pagin_config.to_token.room_key
+
+ # Initially fetch the events from the database. With any luck, we can return
+ # these without blocking on backfill (handled below).
+ events, next_key = await self.store.paginate_room_events(
+ room_id=room_id,
+ from_key=from_token.room_key,
+ to_key=to_room_key,
+ direction=pagin_config.direction,
+ limit=pagin_config.limit,
+ event_filter=event_filter,
+ )
- to_room_key = None
- if pagin_config.to_token:
- to_room_key = pagin_config.to_token.room_key
-
- # Initially fetch the events from the database. With any luck, we can return
- # these without blocking on backfill (handled below).
- events, next_key = await self.store.paginate_room_events(
- room_id=room_id,
- from_key=from_token.room_key,
- to_key=to_room_key,
- direction=pagin_config.direction,
- limit=pagin_config.limit,
- event_filter=event_filter,
+ if pagin_config.direction == Direction.BACKWARDS:
+ # We use a `Set` because there can be multiple events at a given depth
+ # and we only care about looking at the unique continuum of depths to
+ # find gaps.
+ event_depths: Set[int] = {event.depth for event in events}
+ sorted_event_depths = sorted(event_depths)
+
+ # Inspect the depths of the returned events to see if there are any gaps
+ found_big_gap = False
+ number_of_gaps = 0
+ previous_event_depth = (
+ sorted_event_depths[0] if len(sorted_event_depths) > 0 else 0
)
-
- if pagin_config.direction == Direction.BACKWARDS:
- # We use a `Set` because there can be multiple events at a given depth
- # and we only care about looking at the unique continum of depths to
- # find gaps.
- event_depths: Set[int] = {event.depth for event in events}
- sorted_event_depths = sorted(event_depths)
-
- # Inspect the depths of the returned events to see if there are any gaps
- found_big_gap = False
- number_of_gaps = 0
- previous_event_depth = (
- sorted_event_depths[0] if len(sorted_event_depths) > 0 else 0
- )
- for event_depth in sorted_event_depths:
- # We don't expect a negative depth but we'll just deal with it in
- # any case by taking the absolute value to get the true gap between
- # any two integers.
- depth_gap = abs(event_depth - previous_event_depth)
- # A `depth_gap` of 1 is a normal continuous chain to the next event
- # (1 <-- 2 <-- 3) so anything larger indicates a missing event (it's
- # also possible there is no event at a given depth but we can't ever
- # know that for sure)
- if depth_gap > 1:
- number_of_gaps += 1
-
- # We only tolerate a small number single-event long gaps in the
- # returned events because those are most likely just events we've
- # failed to pull in the past. Anything longer than that is probably
- # a sign that we're missing a decent chunk of history and we should
- # try to backfill it.
- #
- # XXX: It's possible we could tolerate longer gaps if we checked
- # that a given events `prev_events` is one that has failed pull
- # attempts and we could just treat it like a dead branch of history
- # for now or at least something that we don't need the block the
- # client on to try pulling.
- #
- # XXX: If we had something like MSC3871 to indicate gaps in the
- # timeline to the client, we could also get away with any sized gap
- # and just have the client refetch the holes as they see fit.
- if depth_gap > 2:
- found_big_gap = True
- break
- previous_event_depth = event_depth
-
- # Backfill in the foreground if we found a big gap, have too many holes,
- # or we don't have enough events to fill the limit that the client asked
- # for.
- missing_too_many_events = (
- number_of_gaps > BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD
+ for event_depth in sorted_event_depths:
+ # We don't expect a negative depth but we'll just deal with it in
+ # any case by taking the absolute value to get the true gap between
+ # any two integers.
+ depth_gap = abs(event_depth - previous_event_depth)
+ # A `depth_gap` of 1 is a normal continuous chain to the next event
+ # (1 <-- 2 <-- 3) so anything larger indicates a missing event (it's
+ # also possible there is no event at a given depth but we can't ever
+ # know that for sure)
+ if depth_gap > 1:
+ number_of_gaps += 1
+
+ # We only tolerate a small number of single-event-long gaps in the
+ # returned events because those are most likely just events we've
+ # failed to pull in the past. Anything longer than that is probably
+ # a sign that we're missing a decent chunk of history and we should
+ # try to backfill it.
+ #
+ # XXX: It's possible we could tolerate longer gaps if we checked
+ # that a given event's `prev_events` is one that has failed pull
+ # attempts and we could just treat it like a dead branch of history
+ # for now, or at least something that we don't need to block the
+ # client on to try pulling.
+ #
+ # XXX: If we had something like MSC3871 to indicate gaps in the
+ # timeline to the client, we could also get away with any sized gap
+ # and just have the client refetch the holes as they see fit.
+ if depth_gap > 2:
+ found_big_gap = True
+ break
+ previous_event_depth = event_depth
+
+ # Backfill in the foreground if we found a big gap, have too many holes,
+ # or we don't have enough events to fill the limit that the client asked
+ # for.
+ missing_too_many_events = (
+ number_of_gaps > BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD
+ )
+ not_enough_events_to_fill_response = len(events) < pagin_config.limit
+ if (
+ found_big_gap
+ or missing_too_many_events
+ or not_enough_events_to_fill_response
+ ):
+ did_backfill = await self.hs.get_federation_handler().maybe_backfill(
+ room_id,
+ curr_topo,
+ limit=pagin_config.limit,
)
- not_enough_events_to_fill_response = len(events) < pagin_config.limit
- if (
- found_big_gap
- or missing_too_many_events
- or not_enough_events_to_fill_response
- ):
- did_backfill = (
- await self.hs.get_federation_handler().maybe_backfill(
- room_id,
- curr_topo,
- limit=pagin_config.limit,
- )
- )
- # If we did backfill something, refetch the events from the database to
- # catch anything new that might have been added since we last fetched.
- if did_backfill:
- events, next_key = await self.store.paginate_room_events(
- room_id=room_id,
- from_key=from_token.room_key,
- to_key=to_room_key,
- direction=pagin_config.direction,
- limit=pagin_config.limit,
- event_filter=event_filter,
- )
- else:
- # Otherwise, we can backfill in the background for eventual
- # consistency's sake but we don't need to block the client waiting
- # for a costly federation call and processing.
- run_as_background_process(
- "maybe_backfill_in_the_background",
- self.hs.get_federation_handler().maybe_backfill,
- room_id,
- curr_topo,
+ # If we did backfill something, refetch the events from the database to
+ # catch anything new that might have been added since we last fetched.
+ if did_backfill:
+ events, next_key = await self.store.paginate_room_events(
+ room_id=room_id,
+ from_key=from_token.room_key,
+ to_key=to_room_key,
+ direction=pagin_config.direction,
limit=pagin_config.limit,
+ event_filter=event_filter,
)
+ else:
+ # Otherwise, we can backfill in the background for eventual
+ # consistency's sake but we don't need to block the client waiting
+ # for a costly federation call and processing.
+ run_as_background_process(
+ "maybe_backfill_in_the_background",
+ self.hs.get_federation_handler().maybe_backfill,
+ room_id,
+ curr_topo,
+ limit=pagin_config.limit,
+ )
- next_token = from_token.copy_and_replace(StreamKeyType.ROOM, next_key)
+ next_token = from_token.copy_and_replace(StreamKeyType.ROOM, next_key)
# if no events are returned from pagination, that implies
# we have reached the end of the available events.
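The gap heuristic that moved out of the lock above counts single-event holes in the returned depths and treats any gap wider than two as a sign that a chunk of history is missing. An illustrative standalone sketch of that check (not Synapse code):

```python
from typing import Iterable, Tuple


def inspect_depth_gaps(depths: Iterable[int]) -> Tuple[int, bool]:
    """Return (number_of_gaps, found_big_gap) for a set of event depths."""
    sorted_depths = sorted(set(depths))
    number_of_gaps = 0
    found_big_gap = False
    previous = sorted_depths[0] if sorted_depths else 0
    for depth in sorted_depths:
        gap = abs(depth - previous)
        # A gap of 1 is a normal continuous chain; anything larger is a hole.
        if gap > 1:
            number_of_gaps += 1
        # A gap wider than 2 means more than one event is missing in a row.
        if gap > 2:
            found_big_gap = True
            break
        previous = depth
    return number_of_gaps, found_big_gap


# inspect_depth_gaps([1, 2, 3, 5, 6]) -> (1, False): one single-event hole at depth 4.
# inspect_depth_gaps([1, 2, 3, 10]) -> (1, True): a large chunk of history is missing.
```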
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 139f57cf86..3b88dc68ea 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -14,7 +14,9 @@
"""A replication client for use by synapse workers.
"""
import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, Optional, Set, Tuple
+
+from sortedcontainers import SortedList
from twisted.internet import defer
from twisted.internet.defer import Deferred
@@ -26,6 +28,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.streams import (
AccountDataStream,
+ CachesStream,
DeviceListsStream,
PushersStream,
PushRulesStream,
@@ -73,6 +76,7 @@ class ReplicationDataHandler:
self._instance_name = hs.get_instance_name()
self._typing_handler = hs.get_typing_handler()
self._state_storage_controller = hs.get_storage_controllers().state
+ self.auth = hs.get_auth()
self._notify_pushers = hs.config.worker.start_pushers
self._pusher_pool = hs.get_pusherpool()
@@ -84,7 +88,9 @@ class ReplicationDataHandler:
# Map from stream and instance to list of deferreds waiting for the stream to
# arrive at a particular position. The lists are sorted by stream position.
- self._streams_to_waiters: Dict[Tuple[str, str], List[Tuple[int, Deferred]]] = {}
+ self._streams_to_waiters: Dict[
+ Tuple[str, str], SortedList[Tuple[int, Deferred]]
+ ] = {}
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
@@ -218,6 +224,16 @@ class ReplicationDataHandler:
self._state_storage_controller.notify_event_un_partial_stated(
row.event_id
)
+ # invalidate the introspection token cache
+ elif stream_name == CachesStream.NAME:
+ for row in rows:
+ if row.cache_func == "introspection_token_invalidation":
+ if row.keys[0] is None:
+ # invalidate the whole cache
+ # mypy ignore - the token cache is defined on MSC3861DelegatedAuth
+ self.auth.invalidate_token_cache() # type: ignore[attr-defined]
+ else:
+ self.auth.invalidate_cached_tokens(row.keys) # type: ignore[attr-defined]
await self._presence_handler.process_replication_rows(
stream_name, instance_name, token, rows
@@ -226,7 +242,9 @@ class ReplicationDataHandler:
# Notify any waiting deferreds. The list is ordered by position so we
# just iterate through the list until we reach a position that is
# greater than the received row position.
- waiting_list = self._streams_to_waiters.get((stream_name, instance_name), [])
+ waiting_list = self._streams_to_waiters.get((stream_name, instance_name))
+ if not waiting_list:
+ return
# Index of first item with a position after the current token, i.e we
# have called all deferreds before this index. If not overwritten by
@@ -250,7 +268,7 @@ class ReplicationDataHandler:
# Drop all entries in the waiting list that were called in the above
# loop. (This maintains the order so no need to resort)
- waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
+ del waiting_list[:index_of_first_deferred_not_called]
for deferred in deferreds_to_callback:
try:
@@ -310,11 +328,10 @@ class ReplicationDataHandler:
)
waiting_list = self._streams_to_waiters.setdefault(
- (stream_name, instance_name), []
+ (stream_name, instance_name), SortedList(key=lambda t: t[0])
)
- waiting_list.append((position, deferred))
- waiting_list.sort(key=lambda t: t[0])
+ waiting_list.add((position, deferred))
# We measure here to get in flight counts and average waiting time.
with Measure(self._clock, "repl.wait_for_stream_position"):
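The waiter list switches from a plain list (append, then re-sort) to a `sortedcontainers.SortedList` keyed on stream position, so inserts stay ordered and the already-notified prefix can be dropped with a slice deletion. A small standalone sketch of that pattern (values are made up):

```python
from sortedcontainers import SortedList

# Waiters are (stream_position, deferred) pairs; plain objects stand in for
# the deferreds here.
waiting_list = SortedList(key=lambda t: t[0])
waiting_list.add((15, object()))
waiting_list.add((5, object()))
waiting_list.add((10, object()))

# Count how many waiters are at or before the received position...
token = 10
index_of_first_not_called = 0
for position, _waiter in waiting_list:
    if position > token:
        break
    index_of_first_not_called += 1

# ...then drop that prefix in one go; no re-sort is needed afterwards.
del waiting_list[:index_of_first_not_called]
assert [position for position, _ in waiting_list] == [15]
```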
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index fe8177ed4d..55e752fda8 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -47,6 +47,7 @@ from synapse.rest.admin.federation import (
ListDestinationsRestServlet,
)
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
+from synapse.rest.admin.oidc import OIDCTokenRevocationRestServlet
from synapse.rest.admin.registration_tokens import (
ListRegistrationTokensRestServlet,
NewRegistrationTokenRestServlet,
@@ -297,6 +298,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
BackgroundUpdateRestServlet(hs).register(http_server)
BackgroundUpdateStartJobRestServlet(hs).register(http_server)
ExperimentalFeaturesRestServlet(hs).register(http_server)
+ if hs.config.experimental.msc3861.enabled:
+ OIDCTokenRevocationRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(
diff --git a/synapse/rest/admin/oidc.py b/synapse/rest/admin/oidc.py
new file mode 100644
index 0000000000..64d2d40550
--- /dev/null
+++ b/synapse/rest/admin/oidc.py
@@ -0,0 +1,55 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from http import HTTPStatus
+from typing import TYPE_CHECKING, Dict, Tuple
+
+from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+class OIDCTokenRevocationRestServlet(RestServlet):
+ """
+ Delete a given token introspection response - identified by the `jti` field - from the
+ introspection token cache when a token is revoked at the authorizing server
+ """
+
+ PATTERNS = admin_patterns("/OIDC_token_revocation/(?P<token_id>[^/]*)")
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ auth = hs.get_auth()
+
+ # If this endpoint is loaded then we must have enabled delegated auth.
+ from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth
+
+ assert isinstance(auth, MSC3861DelegatedAuth)
+
+ self.auth = auth
+ self.store = hs.get_datastores().main
+
+ async def on_DELETE(
+ self, request: SynapseRequest, token_id: str
+ ) -> Tuple[HTTPStatus, Dict]:
+ await assert_requester_is_admin(self.auth, request)
+
+ self.auth._token_cache.invalidate(token_id)
+
+ # make sure we invalidate the cache on any workers
+ await self.store.stream_introspection_token_invalidation((token_id,))
+
+ return HTTPStatus.OK, {}
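A hedged usage sketch for the new servlet: assuming `admin_patterns` applies its usual `/_synapse/admin/v1` prefix, an admin can evict a cached introspection result by its `jti`. The homeserver URL and token below are placeholders.

```python
import requests

SYNAPSE_URL = "https://synapse.example.com"  # placeholder homeserver
ADMIN_TOKEN = "..."  # access token of a server admin
jti = "some-token-id"  # the `jti` from the revoked token's introspection response

resp = requests.delete(
    f"{SYNAPSE_URL}/_synapse/admin/v1/OIDC_token_revocation/{jti}",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    timeout=10,
)
resp.raise_for_status()  # the servlet responds 200 with an empty JSON object
```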
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 04d9ef25b7..240e6254b0 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -109,6 +109,8 @@ class UsersRestServletV2(RestServlet):
)
deactivated = parse_boolean(request, "deactivated", default=False)
+ admins = parse_boolean(request, "admins")
+
# If support for MSC3866 is not enabled, apply no filtering based on the
# `approved` column.
if self._msc3866_enabled:
@@ -146,6 +148,7 @@ class UsersRestServletV2(RestServlet):
name,
guests,
deactivated,
+ admins,
order_by,
direction,
approved,
diff --git a/synapse/server.py b/synapse/server.py
index e753ff0377..7cdd3ea3c2 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -142,6 +142,7 @@ from synapse.util.distributor import Distributor
from synapse.util.macaroons import MacaroonGenerator
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import random_string
+from synapse.util.task_scheduler import TaskScheduler
logger = logging.getLogger(__name__)
@@ -360,6 +361,7 @@ class HomeServer(metaclass=abc.ABCMeta):
"""
for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP:
getattr(self, "get_" + i + "_handler")()
+ self.get_task_scheduler()
def get_reactor(self) -> ISynapseReactor:
"""
@@ -912,6 +914,9 @@ class HomeServer(metaclass=abc.ABCMeta):
"""Usage metrics shared between phone home stats and the prometheus exporter."""
return CommonUsageMetricsManager(self)
- @cache_in_self
def get_worker_locks_handler(self) -> WorkerLocksHandler:
return WorkerLocksHandler(self)
+
+ @cache_in_self
+ def get_task_scheduler(self) -> TaskScheduler:
+ return TaskScheduler(self)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index be67d1ff22..a85633efcd 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -70,6 +70,7 @@ from .state import StateStore
from .stats import StatsStore
from .stream import StreamWorkerStore
from .tags import TagsStore
+from .task_scheduler import TaskSchedulerWorkerStore
from .transactions import TransactionWorkerStore
from .ui_auth import UIAuthStore
from .user_directory import UserDirectoryStore
@@ -127,6 +128,7 @@ class DataStore(
CacheInvalidationWorkerStore,
LockStore,
SessionStore,
+ TaskSchedulerWorkerStore,
):
def __init__(
self,
@@ -168,6 +170,7 @@ class DataStore(
name: Optional[str] = None,
guests: bool = True,
deactivated: bool = False,
+ admins: Optional[bool] = None,
order_by: str = UserSortOrder.NAME.value,
direction: Direction = Direction.FORWARDS,
approved: bool = True,
@@ -184,6 +187,9 @@ class DataStore(
name: search for local part of user_id or display name
guests: whether to in include guest users
deactivated: whether to include deactivated users
+ admins: Optional flag to filter admins. If `True`, only admins are queried;
+ if `False`, admins are excluded from the query. When it is
+ `None` (the default), both admins and non-admins are queried.
order_by: the sort order of the returned list
direction: sort ascending or descending
approved: whether to include approved users
@@ -220,6 +226,12 @@ class DataStore(
if not deactivated:
filters.append("deactivated = 0")
+ if admins is not None:
+ if admins:
+ filters.append("admin = 1")
+ else:
+ filters.append("admin = 0")
+
if not approved:
# We ignore NULL values for the approved flag because these should only
# be already existing users that we consider as already approved.
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2fbd389c71..18905e07b6 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -584,6 +584,19 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
else:
return 0
+ async def stream_introspection_token_invalidation(
+ self, key: Tuple[Optional[str]]
+ ) -> None:
+ """
+ Stream an invalidation request for the introspection token cache to workers
+
+ Args:
+ key: token_id of the introspection token to remove from the cache
+ """
+ await self.send_invalidation_to_replication(
+ "introspection_token_invalidation", key
+ )
+
@wrap_as_background_process("clean_up_old_cache_invalidations")
async def _clean_up_cache_invalidation_wrapper(self) -> None:
"""
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index e4162f846b..fa69a4a298 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -33,6 +33,7 @@ from typing_extensions import Literal
from synapse.api.constants import EduTypes
from synapse.api.errors import Codes, StoreError
+from synapse.config.homeserver import HomeServerConfig
from synapse.logging.opentracing import (
get_active_span_text_map,
set_tag,
@@ -1663,6 +1664,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.device_id_exists_cache: LruCache[
Tuple[str, str], Literal[True]
] = LruCache(cache_name="device_id_exists", max_size=10000)
+ self.config: HomeServerConfig = hs.config
async def store_device(
self,
@@ -1784,6 +1786,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))
+ # TODO: don't nuke the entire cache once there is a way to associate
+ # device_id -> introspection_token
+ if self.config.experimental.msc3861.enabled:
+ # mypy ignore - the token cache is defined on MSC3861DelegatedAuth
+ self.auth._token_cache.invalidate_all() # type: ignore[attr-defined]
+ await self.stream_introspection_token_invalidation((None,))
+
async def update_device(
self, user_id: str, device_id: str, new_display_name: Optional[str] = None
) -> None:
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 534dc32413..fab7008a8f 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -452,33 +452,56 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# sets.
seen_chains: Set[int] = set()
- sql = """
- SELECT event_id, chain_id, sequence_number
- FROM event_auth_chains
- WHERE %s
- """
- for batch in batch_iter(initial_events, 1000):
- clause, args = make_in_list_sql_clause(
- txn.database_engine, "event_id", batch
- )
- txn.execute(sql % (clause,), args)
+ # Fetch the chain cover index for the initial set of events we're
+ # considering.
+ def fetch_chain_info(events_to_fetch: Collection[str]) -> None:
+ sql = """
+ SELECT event_id, chain_id, sequence_number
+ FROM event_auth_chains
+ WHERE %s
+ """
+ for batch in batch_iter(events_to_fetch, 1000):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", batch
+ )
+ txn.execute(sql % (clause,), args)
- for event_id, chain_id, sequence_number in txn:
- chain_info[event_id] = (chain_id, sequence_number)
- seen_chains.add(chain_id)
- chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
+ for event_id, chain_id, sequence_number in txn:
+ chain_info[event_id] = (chain_id, sequence_number)
+ seen_chains.add(chain_id)
+ chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
+
+ fetch_chain_info(initial_events)
# Check that we actually have a chain ID for all the events.
events_missing_chain_info = initial_events.difference(chain_info)
+
+ # The result set to return, i.e. the auth chain difference.
+ result: Set[str] = set()
+
if events_missing_chain_info:
- # This can happen due to e.g. downgrade/upgrade of the server. We
- # raise an exception and fall back to the previous algorithm.
- logger.info(
- "Unexpectedly found that events don't have chain IDs in room %s: %s",
+ # For some reason we have events we haven't calculated the chain
+ # index for, so we need to handle those separately. This should only
+ # happen for older rooms where the server doesn't have all the auth
+ # events.
+ result = self._fixup_auth_chain_difference_sets(
+ txn,
room_id,
- events_missing_chain_info,
+ state_sets=state_sets,
+ events_missing_chain_info=events_missing_chain_info,
+ events_that_have_chain_index=chain_info,
)
- raise _NoChainCoverIndex(room_id)
+
+ # We now need to refetch any events that we have added to the state
+ # sets.
+ new_events_to_fetch = {
+ event_id
+ for state_set in state_sets
+ for event_id in state_set
+ if event_id not in initial_events
+ }
+
+ fetch_chain_info(new_events_to_fetch)
# Corresponds to `state_sets`, except as a map from chain ID to max
# sequence number reachable from the state set.
@@ -487,8 +510,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
chains: Dict[int, int] = {}
set_to_chain.append(chains)
- for event_id in state_set:
- chain_id, seq_no = chain_info[event_id]
+ for state_id in state_set:
+ chain_id, seq_no = chain_info[state_id]
chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
@@ -532,7 +555,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# from *any* state set and the minimum sequence number reachable from
# *all* state sets. Events in that range are in the auth chain
# difference.
- result = set()
# Mapping from chain ID to the range of sequence numbers that should be
# pulled from the database.
@@ -588,6 +610,122 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return result
+ def _fixup_auth_chain_difference_sets(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ state_sets: List[Set[str]],
+ events_missing_chain_info: Set[str],
+ events_that_have_chain_index: Collection[str],
+ ) -> Set[str]:
+ """Helper for `_get_auth_chain_difference_using_cover_index_txn` to
+ handle the case where we haven't calculated the chain cover index for
+ all events.
+
+ This modifies `state_sets` so that they only include events that have a
+ chain cover index, and returns a set of event IDs that are part of the
+ auth difference.
+ """
+
+ # This works similarly to the handling of unpersisted events in
+ # `synapse.state.v2_get_auth_chain_difference`. We use the observation
+ # that if you can split the set of events into two classes X and Y,
+ # where no events in Y have events in X in their auth chain, then we can
+ # calculate the auth difference by considering X and Y separately.
+ #
+ # We do this in three steps:
+ # 1. Compute the set of events without chain cover index belonging to
+ # the auth difference.
+ # 2. Replace the un-indexed events in the state_sets with their auth
+ # events, recursively, until the state_sets contain only indexed
+ # events. We can then calculate the auth difference of those state
+ # sets using the chain cover index.
+ # 3. Add the results of 1 and 2 together.
+
+ # By construction we know that all events that we haven't persisted the
+ # chain cover index for are contained in
+ # `event_auth_chain_to_calculate`, so we pull out the events from those
+ # rather than doing recursive queries to walk the auth chain.
+ #
+ # We pull out those events with their auth events, which gives us enough
+ # information to construct the auth chain of an event up to auth events
+ # that have the chain cover index.
+ sql = """
+ SELECT tc.event_id, ea.auth_id, eac.chain_id IS NOT NULL
+ FROM event_auth_chain_to_calculate AS tc
+ LEFT JOIN event_auth AS ea USING (event_id)
+ LEFT JOIN event_auth_chains AS eac ON (ea.auth_id = eac.event_id)
+ WHERE tc.room_id = ?
+ """
+ txn.execute(sql, (room_id,))
+ event_to_auth_ids: Dict[str, Set[str]] = {}
+ events_that_have_chain_index = set(events_that_have_chain_index)
+ for event_id, auth_id, auth_id_has_chain in txn:
+ s = event_to_auth_ids.setdefault(event_id, set())
+ if auth_id is not None:
+ s.add(auth_id)
+ if auth_id_has_chain:
+ events_that_have_chain_index.add(auth_id)
+
+ if events_missing_chain_info - event_to_auth_ids.keys():
+ # Uh oh, we somehow haven't correctly done the chain cover index,
+ # bail and fall back to the old method.
+ logger.info(
+ "Unexpectedly found that events don't have chain IDs in room %s: %s",
+ room_id,
+ events_missing_chain_info - event_to_auth_ids.keys(),
+ )
+ raise _NoChainCoverIndex(room_id)
+
+ # Create a map from event IDs we care about to their partial auth chain.
+ event_id_to_partial_auth_chain: Dict[str, Set[str]] = {}
+ for event_id, auth_ids in event_to_auth_ids.items():
+ if not any(event_id in state_set for state_set in state_sets):
+ continue
+
+ processing = set(auth_ids)
+ to_add = set()
+ while processing:
+ auth_id = processing.pop()
+ to_add.add(auth_id)
+
+ sub_auth_ids = event_to_auth_ids.get(auth_id)
+ if sub_auth_ids is None:
+ continue
+
+ processing.update(sub_auth_ids - to_add)
+
+ event_id_to_partial_auth_chain[event_id] = to_add
+
+ # Now we do two things:
+ # 1. Update the state sets to only include indexed events; and
+ # 2. Create a new list containing the auth chains of the un-indexed
+ # events
+ unindexed_state_sets: List[Set[str]] = []
+ for state_set in state_sets:
+ unindexed_state_set = set()
+ for event_id, auth_chain in event_id_to_partial_auth_chain.items():
+ if event_id not in state_set:
+ continue
+
+ unindexed_state_set.add(event_id)
+
+ state_set.discard(event_id)
+ state_set.difference_update(auth_chain)
+ for auth_id in auth_chain:
+ if auth_id in events_that_have_chain_index:
+ state_set.add(auth_id)
+ else:
+ unindexed_state_set.add(auth_id)
+
+ unindexed_state_sets.append(unindexed_state_set)
+
+ # Calculate and return the auth difference of the un-indexed events.
+ union = unindexed_state_sets[0].union(*unindexed_state_sets[1:])
+ intersection = unindexed_state_sets[0].intersection(*unindexed_state_sets[1:])
+
+ return union - intersection
+
def _get_auth_chain_difference_txn(
self, txn: LoggingTransaction, state_sets: List[Set[str]]
) -> Set[str]:
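The helper's final step computes the auth chain difference of the un-indexed state sets directly as the union minus the intersection. A tiny illustration with made-up event IDs:

```python
# Two state sets that agree on the room's create/power events but each carry
# one membership event the other lacks.
state_sets = [
    {"$create", "$power_levels", "$alice_join"},
    {"$create", "$power_levels", "$bob_join"},
]

union = state_sets[0].union(*state_sets[1:])
intersection = state_sets[0].intersection(*state_sets[1:])

# Events reachable from *some* but not *all* state sets form the difference.
assert union - intersection == {"$alice_join", "$bob_join"}
```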
diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py
new file mode 100644
index 0000000000..1fb3180c3c
--- /dev/null
+++ b/synapse/storage/databases/main/task_scheduler.py
@@ -0,0 +1,202 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+ make_in_list_sql_clause,
+)
+from synapse.types import JsonDict, JsonMapping, ScheduledTask, TaskStatus
+from synapse.util import json_encoder
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+class TaskSchedulerWorkerStore(SQLBaseStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ @staticmethod
+ def _convert_row_to_task(row: Dict[str, Any]) -> ScheduledTask:
+ row["status"] = TaskStatus(row["status"])
+ if row["params"] is not None:
+ row["params"] = db_to_json(row["params"])
+ if row["result"] is not None:
+ row["result"] = db_to_json(row["result"])
+ return ScheduledTask(**row)
+
+ async def get_scheduled_tasks(
+ self,
+ *,
+ actions: Optional[List[str]] = None,
+ resource_id: Optional[str] = None,
+ statuses: Optional[List[TaskStatus]] = None,
+ max_timestamp: Optional[int] = None,
+ ) -> List[ScheduledTask]:
+ """Get a list of scheduled tasks from the DB.
+
+ Args:
+ actions: Limit the returned tasks to those specific action names
+ resource_id: Limit the returned tasks to the specific resource id, if specified
+ statuses: Limit the returned tasks to the specific statuses
+ max_timestamp: Limit the returned tasks to the ones that have
+ a timestamp lower than or equal to the specified one
+
+ Returns: a list of `ScheduledTask`, ordered by increasing timestamps
+ """
+
+ def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ clauses: List[str] = []
+ args: List[Any] = []
+ if resource_id:
+ clauses.append("resource_id = ?")
+ args.append(resource_id)
+ if actions is not None:
+ clause, temp_args = make_in_list_sql_clause(
+ txn.database_engine, "action", actions
+ )
+ clauses.append(clause)
+ args.extend(temp_args)
+ if statuses is not None:
+ clause, temp_args = make_in_list_sql_clause(
+ txn.database_engine, "status", statuses
+ )
+ clauses.append(clause)
+ args.extend(temp_args)
+ if max_timestamp is not None:
+ clauses.append("timestamp <= ?")
+ args.append(max_timestamp)
+
+ sql = "SELECT * FROM scheduled_tasks"
+ if clauses:
+ sql = sql + " WHERE " + " AND ".join(clauses)
+
+ sql = sql + "ORDER BY timestamp"
+
+ txn.execute(sql, args)
+ return self.db_pool.cursor_to_dict(txn)
+
+ rows = await self.db_pool.runInteraction(
+ "get_scheduled_tasks", get_scheduled_tasks_txn
+ )
+ return [TaskSchedulerWorkerStore._convert_row_to_task(row) for row in rows]
+
+ async def insert_scheduled_task(self, task: ScheduledTask) -> None:
+ """Insert a specified `ScheduledTask` in the DB.
+
+ Args:
+ task: the `ScheduledTask` to insert
+ """
+ await self.db_pool.simple_insert(
+ "scheduled_tasks",
+ {
+ "id": task.id,
+ "action": task.action,
+ "status": task.status,
+ "timestamp": task.timestamp,
+ "resource_id": task.resource_id,
+ "params": None
+ if task.params is None
+ else json_encoder.encode(task.params),
+ "result": None
+ if task.result is None
+ else json_encoder.encode(task.result),
+ "error": task.error,
+ },
+ desc="insert_scheduled_task",
+ )
+
+ async def update_scheduled_task(
+ self,
+ id: str,
+ timestamp: int,
+ *,
+ status: Optional[TaskStatus] = None,
+ result: Optional[JsonMapping] = None,
+ error: Optional[str] = None,
+ ) -> bool:
+ """Update a scheduled task in the DB with some new value(s).
+
+ Args:
+ id: id of the `ScheduledTask` to update
+ timestamp: new timestamp of the task
+ status: new status of the task
+ result: new result of the task
+ error: new error of the task
+
+ Returns: `False` if no matching row was found, `True` otherwise
+ """
+ updatevalues: JsonDict = {"timestamp": timestamp}
+ if status is not None:
+ updatevalues["status"] = status
+ if result is not None:
+ updatevalues["result"] = json_encoder.encode(result)
+ if error is not None:
+ updatevalues["error"] = error
+ nb_rows = await self.db_pool.simple_update(
+ "scheduled_tasks",
+ {"id": id},
+ updatevalues,
+ desc="update_scheduled_task",
+ )
+ return nb_rows > 0
+
+ async def get_scheduled_task(self, id: str) -> Optional[ScheduledTask]:
+ """Get a specific `ScheduledTask` from its id.
+
+ Args:
+ id: the id of the task to retrieve
+
+ Returns: the task if available, `None` otherwise
+ """
+ row = await self.db_pool.simple_select_one(
+ table="scheduled_tasks",
+ keyvalues={"id": id},
+ retcols=(
+ "id",
+ "action",
+ "status",
+ "timestamp",
+ "resource_id",
+ "params",
+ "result",
+ "error",
+ ),
+ allow_none=True,
+ desc="get_scheduled_task",
+ )
+
+ return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
+
+ async def delete_scheduled_task(self, id: str) -> None:
+ """Delete a specific task from its id.
+
+ Args:
+ id: the id of the task to delete
+ """
+ await self.db_pool.simple_delete(
+ "scheduled_tasks",
+ keyvalues={"id": id},
+ desc="delete_scheduled_task",
+ )
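A hedged usage sketch of the new store methods (the `purge_room` action name, the calling coroutine, and the `store` object are illustrative; in Synapse these calls would normally go through the `TaskScheduler` rather than the store directly):

```python
from synapse.types import ScheduledTask, TaskStatus
from synapse.util.stringutils import random_string


async def schedule_purge(store, room_id: str, now_ms: int) -> None:
    # Persist a task bound to the room via resource_id...
    task = ScheduledTask(
        id=random_string(16),
        action="purge_room",  # hypothetical action name
        status=TaskStatus.SCHEDULED,
        timestamp=now_ms,
        resource_id=room_id,
        params={"force": False},
        result=None,
        error=None,
    )
    await store.insert_scheduled_task(task)

    # ...and later fetch everything still pending for that room.
    pending = await store.get_scheduled_tasks(
        resource_id=room_id,
        statuses=[TaskStatus.SCHEDULED, TaskStatus.ACTIVE],
    )
    assert any(t.id == task.id for t in pending)
```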
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index c3bd36efc9..48e4b0ba3c 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -242,6 +242,8 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
) -> None:
# Upsert retry time interval if retry_interval is zero (i.e. we're
# resetting it) or greater than the existing retry interval.
+ # We also upsert when the new retry interval is the same as the existing one,
+ # since it will be the case when `destination_max_retry_interval` is reached.
#
# WARNING: This is executed in autocommit, so we shouldn't add any more
# SQL calls in here (without being very careful).
@@ -257,7 +259,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
WHERE
EXCLUDED.retry_interval = 0
OR destinations.retry_interval IS NULL
- OR destinations.retry_interval < EXCLUDED.retry_interval
+ OR destinations.retry_interval <= EXCLUDED.retry_interval
"""
txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
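The comparison was relaxed from `<` to `<=` because once exponential backoff hits `destination_max_retry_interval`, the newly computed interval equals the stored one, and the upsert should still run so that `retry_last_ts`/`failure_ts` keep being refreshed. A small arithmetic sketch with hypothetical values:

```python
MULTIPLIER = 2
MAX_RETRY_INTERVAL_MS = 60 * 60 * 1000  # stand-in for destination_max_retry_interval


def next_retry_interval(current_ms: int) -> int:
    # Double the interval but never exceed the configured cap.
    return min(current_ms * MULTIPLIER, MAX_RETRY_INTERVAL_MS)


interval = next_retry_interval(45 * 60 * 1000)  # 90 minutes, capped to 60
assert interval == MAX_RETRY_INTERVAL_MS
# Once capped, the "new" interval equals the old one, so a strict `<` in the
# upsert condition would skip the row update entirely.
assert next_retry_interval(interval) == interval
```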
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 7de9949a5b..649d3c8e9f 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -113,6 +113,7 @@ Changes in SCHEMA_VERSION = 79
Changes in SCHEMA_VERSION = 80
- The event_txn_id_device_id is always written to for new events.
+ - Add tables for the task scheduler.
"""
diff --git a/synapse/storage/schema/main/delta/80/02_read_write_locks_unlogged.sql.postgres b/synapse/storage/schema/main/delta/80/02_read_write_locks_unlogged.sql.postgres
new file mode 100644
index 0000000000..5b5dbf2687
--- /dev/null
+++ b/synapse/storage/schema/main/delta/80/02_read_write_locks_unlogged.sql.postgres
@@ -0,0 +1,30 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Mark the worker_read_write_locks* tables as UNLOGGED, to increase
+-- performance. This means that we don't replicate the tables, and they get
+-- truncated on a crash. This is acceptable as a) in those cases it's likely
+-- that Synapse needs to be stopped/restarted anyway, and b) the locks are
+-- considered best-effort anyway.
+
+-- We need to remove and recreate the circular foreign key references, as
+-- UNLOGGED tables can't reference normal tables.
+ALTER TABLE worker_read_write_locks_mode DROP CONSTRAINT IF EXISTS worker_read_write_locks_mode_foreign;
+
+ALTER TABLE worker_read_write_locks SET UNLOGGED;
+ALTER TABLE worker_read_write_locks_mode SET UNLOGGED;
+
+ALTER TABLE worker_read_write_locks_mode ADD CONSTRAINT worker_read_write_locks_mode_foreign
+ FOREIGN KEY (lock_name, lock_key, token) REFERENCES worker_read_write_locks(lock_name, lock_key, token) DEFERRABLE INITIALLY DEFERRED;
diff --git a/synapse/storage/schema/main/delta/80/02_scheduled_tasks.sql b/synapse/storage/schema/main/delta/80/02_scheduled_tasks.sql
new file mode 100644
index 0000000000..286d109ed7
--- /dev/null
+++ b/synapse/storage/schema/main/delta/80/02_scheduled_tasks.sql
@@ -0,0 +1,28 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- cf ScheduledTask docstring for the meaning of the fields.
+CREATE TABLE IF NOT EXISTS scheduled_tasks(
+ id TEXT PRIMARY KEY,
+ action TEXT NOT NULL,
+ status TEXT NOT NULL,
+ timestamp BIGINT NOT NULL,
+ resource_id TEXT,
+ params TEXT,
+ result TEXT,
+ error TEXT
+);
+
+CREATE INDEX IF NOT EXISTS scheduled_tasks_status ON scheduled_tasks(status);
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 073f682aca..e750417189 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -15,6 +15,7 @@
import abc
import re
import string
+from enum import Enum
from typing import (
TYPE_CHECKING,
AbstractSet,
@@ -969,3 +970,41 @@ class UserProfile(TypedDict):
class RetentionPolicy:
min_lifetime: Optional[int] = None
max_lifetime: Optional[int] = None
+
+
+class TaskStatus(str, Enum):
+ """Status of a scheduled task"""
+
+ # Task is scheduled but not active
+ SCHEDULED = "scheduled"
+ # Task is active and probably running; if not, it
+ # will be run on the next scheduler loop run
+ ACTIVE = "active"
+ # Task has completed successfully
+ COMPLETE = "complete"
+ # Task is over and either returned a failed status, or had an exception
+ FAILED = "failed"
+
+
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class ScheduledTask:
+ """Description of a scheduled task"""
+
+ # Id used to identify the task
+ id: str
+ # Name of the action to be run by this task
+ action: str
+ # Current status of this task
+ status: TaskStatus
+ # If the status is SCHEDULED then this represents when it should be launched,
+ # otherwise it represents the last time this task got a change of state.
+ # In milliseconds since epoch in system time timezone, usually UTC.
+ timestamp: int
+ # Optionally bind a task to some resource id for easy retrieval
+ resource_id: Optional[str]
+ # Optional parameters that will be passed to the function run by the task
+ params: Optional[JsonMapping]
+ # Optional result that can be updated by the running task
+ result: Optional[JsonMapping]
+ # Optional error that should be assigned a value when the status is FAILED
+ error: Optional[str]
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 01ad02af67..9a3e10ddee 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -140,6 +140,20 @@ class ExpiringCache(Generic[KT, VT]):
return value.value
+ def invalidate(self, key: KT) -> None:
+ """
+ Remove the given key from the cache.
+ """
+
+ value = self._cache.pop(key, None)
+ if value:
+ if self.iterable:
+ self.metrics.inc_evictions(
+ EvictionReason.invalidation, len(value.value)
+ )
+ else:
+ self.metrics.inc_evictions(EvictionReason.invalidation)
+
def __contains__(self, key: KT) -> bool:
return key in self._cache
@@ -193,6 +207,14 @@ class ExpiringCache(Generic[KT, VT]):
len(self),
)
+ def invalidate_all(self) -> None:
+ """
+ Remove all items from the cache.
+ """
+ keys = set(self._cache.keys())
+ for key in keys:
+ self._cache.pop(key)
+
def __len__(self) -> int:
if self.iterable:
return sum(len(entry.value) for entry in self._cache.values())
diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py
new file mode 100644
index 0000000000..773a8327f6
--- /dev/null
+++ b/synapse/util/task_scheduler.py
@@ -0,0 +1,364 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set, Tuple
+
+from prometheus_client import Gauge
+
+from twisted.python.failure import Failure
+
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import JsonMapping, ScheduledTask, TaskStatus
+from synapse.util.stringutils import random_string
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+running_tasks_gauge = Gauge(
+ "synapse_scheduler_running_tasks",
+ "The number of concurrent running tasks handled by the TaskScheduler",
+)
+
+
+class TaskScheduler:
+ """
+    This is a simple task scheduler aimed at resumable tasks: usually we use `run_in_background`
+ to launch a background task, or Twisted `deferLater` if we want to do so later on.
+
+ The problem with that is that the tasks will just stop and never be resumed if synapse
+ is stopped for whatever reason.
+
+ How this works:
+ - A function mapped to a named action should first be registered with `register_action`.
+    This function will be called when trying to resume tasks after a Synapse shutdown,
+ so this registration should happen when synapse is initialised, NOT right before scheduling
+ a task.
+ - A task can then be launched using this named action with `schedule_task`. A `params` dict
+ can be passed, and it will be available to the registered function when launched. This task
+    can be launched either now-ish, or later on by giving a `timestamp` parameter.
+
+ The function may call `update_task` at any time to update the `result` of the task,
+ and this can be used to resume the task at a specific point and/or to convey a result to
+ the code launching the task.
+ You can also specify the `result` (and/or an `error`) when returning from the function.
+
+    The reconciliation loop runs every `SCHEDULE_INTERVAL_MS` (currently one minute), so this
+    is not a precise scheduler: even when asking to launch a task now, the launch will not
+    happen before the next loop run.
+
+ Tasks will be run on the worker specified with `run_background_tasks_on` config,
+ or the main one by default.
+ There is a limit of 10 concurrent tasks, so tasks may be delayed if the pool is already
+    full. In this regard, please take great care that scheduled tasks can actually finish.
+ For now there is no mechanism to stop a running task if it is stuck.
+ """
+
+ # Precision of the scheduler, evaluation of tasks to run will only happen
+ # every `SCHEDULE_INTERVAL_MS` ms
+    SCHEDULE_INTERVAL_MS = 1 * 60 * 1000  # 1 minute
+ # Time before a complete or failed task is deleted from the DB
+ KEEP_TASKS_FOR_MS = 7 * 24 * 60 * 60 * 1000 # 1 week
+ # Maximum number of tasks that can run at the same time
+ MAX_CONCURRENT_RUNNING_TASKS = 10
+ # Time from the last task update after which we will log a warning
+ LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000 # 24hrs
+
+ def __init__(self, hs: "HomeServer"):
+ self._store = hs.get_datastores().main
+ self._clock = hs.get_clock()
+ self._running_tasks: Set[str] = set()
+ # A map between action names and their registered function
+ self._actions: Dict[
+ str,
+ Callable[
+ [ScheduledTask, bool],
+ Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]],
+ ],
+ ] = {}
+ self._run_background_tasks = hs.config.worker.run_background_tasks
+
+ if self._run_background_tasks:
+ self._clock.looping_call(
+ run_as_background_process,
+ TaskScheduler.SCHEDULE_INTERVAL_MS,
+ "handle_scheduled_tasks",
+ self._handle_scheduled_tasks,
+ )
+
+ def register_action(
+ self,
+ function: Callable[
+ [ScheduledTask, bool],
+ Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]],
+ ],
+ action_name: str,
+ ) -> None:
+ """Register a function to be executed when an action is scheduled with
+ the specified action name.
+
+ Actions need to be registered as early as possible so that a resumed action
+ can find its matching function. It's usually better to NOT do that right before
+ calling `schedule_task` but rather in an `__init__` method.
+
+ Args:
+ function: The function to be executed for this action. The parameters
+ passed to the function when launched are the `ScheduledTask` being run,
+ and a `first_launch` boolean to signal if it's a resumed task or the first
+ launch of it. The function should return a tuple of new `status`, `result`
+ and `error` as specified in `ScheduledTask`.
+ action_name: The name of the action to be associated with the function
+ """
+ self._actions[action_name] = function
+
+ async def schedule_task(
+ self,
+ action: str,
+ *,
+ resource_id: Optional[str] = None,
+ timestamp: Optional[int] = None,
+ params: Optional[JsonMapping] = None,
+ ) -> str:
+ """Schedule a new potentially resumable task. A function matching the specified
+ `action` should have been previously registered with `register_action`.
+
+ Args:
+ action: the name of a previously registered action
+ resource_id: a task can be associated with a resource id to facilitate
+ getting all tasks associated with a specific resource
+ timestamp: if `None`, the task will be launched as soon as possible, otherwise it
+                will be launched as soon as possible after the `timestamp` value.
+ Note that this scheduler is not meant to be precise, and the scheduling
+ could be delayed if too many tasks are already running
+ params: a set of parameters that can be easily accessed from inside the
+ executed function
+
+ Returns:
+ The id of the scheduled task
+ """
+ if action not in self._actions:
+ raise Exception(
+ f"No function associated with action {action} of the scheduled task"
+ )
+
+ if timestamp is None or timestamp < self._clock.time_msec():
+ timestamp = self._clock.time_msec()
+
+ task = ScheduledTask(
+ random_string(16),
+ action,
+ TaskStatus.SCHEDULED,
+ timestamp,
+ resource_id,
+ params,
+ result=None,
+ error=None,
+ )
+ await self._store.insert_scheduled_task(task)
+
+ return task.id
+
+ async def update_task(
+ self,
+ id: str,
+ *,
+ timestamp: Optional[int] = None,
+ status: Optional[TaskStatus] = None,
+ result: Optional[JsonMapping] = None,
+ error: Optional[str] = None,
+ ) -> bool:
+        """Update some task-associated values. This is exposed publicly so it can
+ be used inside task functions, mainly to update the result and be able to
+ resume a task at a specific step after a restart of synapse.
+
+ It can also be used to stage a task, by setting the `status` to `SCHEDULED` with
+ a new timestamp.
+
+        The `status` can only be set to `ACTIVE` or `SCHEDULED`; `COMPLETE` and `FAILED`
+        are terminal statuses and can only be set by returning them from the task function.
+
+ Args:
+ id: the id of the task to update
+ timestamp: useful to schedule a new stage of the task at a later date
+ status: the new `TaskStatus` of the task
+ result: the new result of the task
+            error: the new error of the task
+
+        Returns:
+            True if the update was successful, False otherwise.
+        """
+ if status == TaskStatus.COMPLETE or status == TaskStatus.FAILED:
+ raise Exception(
+ "update_task can't be called with a FAILED or COMPLETE status"
+ )
+
+ if timestamp is None:
+ timestamp = self._clock.time_msec()
+ return await self._store.update_scheduled_task(
+ id,
+ timestamp,
+ status=status,
+ result=result,
+ error=error,
+ )
+
+ async def get_task(self, id: str) -> Optional[ScheduledTask]:
+ """Get a specific task description by id.
+
+ Args:
+ id: the id of the task to retrieve
+
+ Returns:
+ The task information or `None` if it doesn't exist or it has
+ already been removed because it's too old.
+ """
+ return await self._store.get_scheduled_task(id)
+
+ async def get_tasks(
+ self,
+ *,
+ actions: Optional[List[str]] = None,
+ resource_id: Optional[str] = None,
+ statuses: Optional[List[TaskStatus]] = None,
+ max_timestamp: Optional[int] = None,
+ ) -> List[ScheduledTask]:
+        """Get a list of tasks. Returns all the tasks if no args are provided.
+
+ If an arg is `None` all tasks matching the other args will be selected.
+ If an arg is an empty list, the corresponding value of the task needs
+ to be `None` to be selected.
+
+ Args:
+ actions: Limit the returned tasks to those specific action names
+ resource_id: Limit the returned tasks to the specific resource id, if specified
+ statuses: Limit the returned tasks to the specific statuses
+ max_timestamp: Limit the returned tasks to the ones that have
+                a timestamp lower than the specified one
+
+        Returns:
+ A list of `ScheduledTask`, ordered by increasing timestamps
+ """
+ return await self._store.get_scheduled_tasks(
+ actions=actions,
+ resource_id=resource_id,
+ statuses=statuses,
+ max_timestamp=max_timestamp,
+ )
+
+ async def delete_task(self, id: str) -> None:
+ """Delete a task. Running tasks can't be deleted.
+
+ Can only be called from the worker handling the task scheduling.
+
+ Args:
+ id: id of the task to delete
+ """
+ if self.task_is_running(id):
+ raise Exception(f"Task {id} is currently running and can't be deleted")
+ await self._store.delete_scheduled_task(id)
+
+ def task_is_running(self, id: str) -> bool:
+ """Check if a task is currently running.
+
+ Can only be called from the worker handling the task scheduling.
+
+ Args:
+ id: id of the task to check
+ """
+ assert self._run_background_tasks
+ return id in self._running_tasks
+
+ async def _handle_scheduled_tasks(self) -> None:
+ """Main loop taking care of launching tasks and cleaning up old ones."""
+ await self._launch_scheduled_tasks()
+ await self._clean_scheduled_tasks()
+
+ async def _launch_scheduled_tasks(self) -> None:
+        """Retrieve and launch scheduled tasks that should be running now."""
+ for task in await self.get_tasks(statuses=[TaskStatus.ACTIVE]):
+ if not self.task_is_running(task.id):
+ if (
+ len(self._running_tasks)
+ < TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS
+ ):
+ await self._launch_task(task, first_launch=False)
+ else:
+ if (
+ self._clock.time_msec()
+ > task.timestamp + TaskScheduler.LAST_UPDATE_BEFORE_WARNING_MS
+ ):
+                    logger.warning(
+ f"Task {task.id} (action {task.action}) has seen no update for more than 24h and may be stuck"
+ )
+ for task in await self.get_tasks(
+ statuses=[TaskStatus.SCHEDULED], max_timestamp=self._clock.time_msec()
+ ):
+ if (
+ not self.task_is_running(task.id)
+ and len(self._running_tasks)
+ < TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS
+ ):
+ await self._launch_task(task, first_launch=True)
+
+ running_tasks_gauge.set(len(self._running_tasks))
+
+ async def _clean_scheduled_tasks(self) -> None:
+        """Clean up old complete or failed tasks to avoid cluttering the DB."""
+ for task in await self._store.get_scheduled_tasks(
+ statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE]
+ ):
+ # FAILED and COMPLETE tasks should never be running
+ assert not self.task_is_running(task.id)
+ if (
+ self._clock.time_msec()
+ > task.timestamp + TaskScheduler.KEEP_TASKS_FOR_MS
+ ):
+ await self._store.delete_scheduled_task(task.id)
+
+ async def _launch_task(self, task: ScheduledTask, first_launch: bool) -> None:
+ """Launch a scheduled task now.
+
+ Args:
+ task: the task to launch
+            first_launch: `True` if it's the first time the task is launched, `False` otherwise
+ """
+ assert task.action in self._actions
+
+ function = self._actions[task.action]
+
+ async def wrapper() -> None:
+ try:
+ (status, result, error) = await function(task, first_launch)
+ except Exception:
+ f = Failure()
+ logger.error(
+ f"scheduled task {task.id} failed",
+ exc_info=(f.type, f.value, f.getTracebackObject()),
+ )
+ status = TaskStatus.FAILED
+ result = None
+ error = f.getErrorMessage()
+
+ await self._store.update_scheduled_task(
+ task.id,
+ self._clock.time_msec(),
+ status=status,
+ result=result,
+ error=error,
+ )
+ self._running_tasks.remove(task.id)
+
+ self._running_tasks.add(task.id)
+ await self.update_task(task.id, status=TaskStatus.ACTIVE)
+ description = f"{task.id}-{task.action}"
+ run_as_background_process(description, wrapper)
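
To tie the pieces together, a hedged usage sketch of the scheduler API above. The handler class, action name, params and the `hs.get_task_scheduler()` accessor are assumptions (the accessor is expected to be wired up elsewhere); only `register_action`, `schedule_task`, `update_task` and the callback signature come from the code above.

from typing import TYPE_CHECKING, Optional, Tuple

from synapse.types import JsonMapping, ScheduledTask, TaskStatus

if TYPE_CHECKING:
    from synapse.server import HomeServer


class ExamplePurgeHandler:
    """Hypothetical handler showing how a resumable action could be wired up."""

    ACTION_NAME = "example_purge"  # made-up action name

    def __init__(self, hs: "HomeServer") -> None:
        # Register at init time so that tasks interrupted by a restart can be
        # resumed; `hs.get_task_scheduler()` is an assumed accessor.
        self._task_scheduler = hs.get_task_scheduler()
        self._task_scheduler.register_action(self._do_purge, self.ACTION_NAME)

    async def start_purge(self, room_id: str) -> str:
        # Schedule the work; the returned id can later be polled via get_task().
        return await self._task_scheduler.schedule_task(
            self.ACTION_NAME,
            resource_id=room_id,
            params={"batch_size": 100},
        )

    async def _do_purge(
        self, task: ScheduledTask, first_launch: bool
    ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
        # `task.params` carries the scheduling-time parameters (here a
        # hypothetical batch size); `task.result` can carry checkpoints
        # written via update_task(), so a resumed run (first_launch=False)
        # can pick up where it left off.
        done = (task.result or {}).get("batches_done", 0)
        # ... do one batch of work here, then checkpoint:
        await self._task_scheduler.update_task(
            task.id, result={"batches_done": done + 1}
        )
        # Returning a terminal status ends the task.
        return TaskStatus.COMPLETE, {"batches_done": done + 1}, None
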
|