From a00462dd9927558532b030593f8914ade53b7214 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 9 May 2022 12:31:14 +0100 Subject: Implement cancellation support/protection for module callbacks (#12568) There's no guarantee that module callbacks will handle cancellation appropriately. Protect module callbacks with read semantics from cancellation and avoid swallowing `CancelledError`s that arise. Other module callbacks, such as the `on_*` callbacks, are presumed to live on code paths that involve writes and aren't cancellation-friendly. These module callbacks have been left alone. Signed-off-by: Sean Quah --- synapse/handlers/account_validity.py | 3 ++- synapse/handlers/auth.py | 25 +++++++++++++++++++------ 2 files changed, 21 insertions(+), 7 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 05a138410e..33e45e3a11 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -23,6 +23,7 @@ from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.types import UserID from synapse.util import stringutils +from synapse.util.async_helpers import delay_cancellation if TYPE_CHECKING: from synapse.server import HomeServer @@ -150,7 +151,7 @@ class AccountValidityHandler: Whether the user has expired. """ for callback in self._is_user_expired_callbacks: - expired = await callback(user_id) + expired = await delay_cancellation(callback(user_id)) if expired is not None: return expired diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index ad41337b28..1b9050ea96 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -41,6 +41,7 @@ import pymacaroons import unpaddedbase64 from pymacaroons.exceptions import MacaroonVerificationFailedException +from twisted.internet.defer import CancelledError from twisted.web.server import Request from synapse.api.constants import LoginType @@ -67,7 +68,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.roommember import ProfileInfo from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils -from synapse.util.async_helpers import maybe_awaitable +from synapse.util.async_helpers import delay_cancellation, maybe_awaitable from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import base62_encode @@ -2202,7 +2203,11 @@ class PasswordAuthProvider: # other than None (i.e. until a callback returns a success) for callback in self.auth_checker_callbacks[login_type]: try: - result = await callback(username, login_type, login_dict) + result = await delay_cancellation( + callback(username, login_type, login_dict) + ) + except CancelledError: + raise except Exception as e: logger.warning("Failed to run module API callback %s: %s", callback, e) continue @@ -2263,7 +2268,9 @@ class PasswordAuthProvider: for callback in self.check_3pid_auth_callbacks: try: - result = await callback(medium, address, password) + result = await delay_cancellation(callback(medium, address, password)) + except CancelledError: + raise except Exception as e: logger.warning("Failed to run module API callback %s: %s", callback, e) continue @@ -2345,7 +2352,7 @@ class PasswordAuthProvider: """ for callback in self.get_username_for_registration_callbacks: try: - res = await callback(uia_results, params) + res = await delay_cancellation(callback(uia_results, params)) if isinstance(res, str): return res @@ -2359,6 +2366,8 @@ class PasswordAuthProvider: callback, res, ) + except CancelledError: + raise except Exception as e: logger.error( "Module raised an exception in get_username_for_registration: %s", @@ -2388,7 +2397,7 @@ class PasswordAuthProvider: """ for callback in self.get_displayname_for_registration_callbacks: try: - res = await callback(uia_results, params) + res = await delay_cancellation(callback(uia_results, params)) if isinstance(res, str): return res @@ -2402,6 +2411,8 @@ class PasswordAuthProvider: callback, res, ) + except CancelledError: + raise except Exception as e: logger.error( "Module raised an exception in get_displayname_for_registration: %s", @@ -2429,7 +2440,7 @@ class PasswordAuthProvider: """ for callback in self.is_3pid_allowed_callbacks: try: - res = await callback(medium, address, registration) + res = await delay_cancellation(callback(medium, address, registration)) if res is False: return res @@ -2443,6 +2454,8 @@ class PasswordAuthProvider: callback, res, ) + except CancelledError: + raise except Exception as e: logger.error("Module raised an exception in is_3pid_allowed: %s", e) raise SynapseError(code=500, msg="Internal Server Error") -- cgit 1.4.1