summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
authorSean Quah <8349537+squahtx@users.noreply.github.com>2022-05-09 12:31:14 +0100
committerGitHub <noreply@github.com>2022-05-09 12:31:14 +0100
commita00462dd9927558532b030593f8914ade53b7214 (patch)
treed15933417e46d7b451298238b39bced69ebffa0b /synapse/handlers
parentFix mypy against latest pillow stubs (#12671) (diff)
downloadsynapse-a00462dd9927558532b030593f8914ade53b7214.tar.xz
Implement cancellation support/protection for module callbacks (#12568)
There's no guarantee that module callbacks will handle cancellation
appropriately. Protect module callbacks with read semantics from
cancellation and avoid swallowing `CancelledError`s that arise.

Other module callbacks, such as the `on_*` callbacks, are presumed to
live on code paths that involve writes and aren't cancellation-friendly.
These module callbacks have been left alone.

Signed-off-by: Sean Quah <seanq@element.io>
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/account_validity.py3
-rw-r--r--synapse/handlers/auth.py25
2 files changed, 21 insertions, 7 deletions
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 05a138410e..33e45e3a11 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -23,6 +23,7 @@ from synapse.api.errors import AuthError, StoreError, SynapseError
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.types import UserID
 from synapse.util import stringutils
+from synapse.util.async_helpers import delay_cancellation
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -150,7 +151,7 @@ class AccountValidityHandler:
             Whether the user has expired.
         """
         for callback in self._is_user_expired_callbacks:
-            expired = await callback(user_id)
+            expired = await delay_cancellation(callback(user_id))
             if expired is not None:
                 return expired
 
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index ad41337b28..1b9050ea96 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -41,6 +41,7 @@ import pymacaroons
 import unpaddedbase64
 from pymacaroons.exceptions import MacaroonVerificationFailedException
 
+from twisted.internet.defer import CancelledError
 from twisted.web.server import Request
 
 from synapse.api.constants import LoginType
@@ -67,7 +68,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.roommember import ProfileInfo
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util import stringutils as stringutils
-from synapse.util.async_helpers import maybe_awaitable
+from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
 from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.stringutils import base62_encode
@@ -2202,7 +2203,11 @@ class PasswordAuthProvider:
         # other than None (i.e. until a callback returns a success)
         for callback in self.auth_checker_callbacks[login_type]:
             try:
-                result = await callback(username, login_type, login_dict)
+                result = await delay_cancellation(
+                    callback(username, login_type, login_dict)
+                )
+            except CancelledError:
+                raise
             except Exception as e:
                 logger.warning("Failed to run module API callback %s: %s", callback, e)
                 continue
@@ -2263,7 +2268,9 @@ class PasswordAuthProvider:
 
         for callback in self.check_3pid_auth_callbacks:
             try:
-                result = await callback(medium, address, password)
+                result = await delay_cancellation(callback(medium, address, password))
+            except CancelledError:
+                raise
             except Exception as e:
                 logger.warning("Failed to run module API callback %s: %s", callback, e)
                 continue
@@ -2345,7 +2352,7 @@ class PasswordAuthProvider:
         """
         for callback in self.get_username_for_registration_callbacks:
             try:
-                res = await callback(uia_results, params)
+                res = await delay_cancellation(callback(uia_results, params))
 
                 if isinstance(res, str):
                     return res
@@ -2359,6 +2366,8 @@ class PasswordAuthProvider:
                         callback,
                         res,
                     )
+            except CancelledError:
+                raise
             except Exception as e:
                 logger.error(
                     "Module raised an exception in get_username_for_registration: %s",
@@ -2388,7 +2397,7 @@ class PasswordAuthProvider:
         """
         for callback in self.get_displayname_for_registration_callbacks:
             try:
-                res = await callback(uia_results, params)
+                res = await delay_cancellation(callback(uia_results, params))
 
                 if isinstance(res, str):
                     return res
@@ -2402,6 +2411,8 @@ class PasswordAuthProvider:
                         callback,
                         res,
                     )
+            except CancelledError:
+                raise
             except Exception as e:
                 logger.error(
                     "Module raised an exception in get_displayname_for_registration: %s",
@@ -2429,7 +2440,7 @@ class PasswordAuthProvider:
         """
         for callback in self.is_3pid_allowed_callbacks:
             try:
-                res = await callback(medium, address, registration)
+                res = await delay_cancellation(callback(medium, address, registration))
 
                 if res is False:
                     return res
@@ -2443,6 +2454,8 @@ class PasswordAuthProvider:
                         callback,
                         res,
                     )
+            except CancelledError:
+                raise
             except Exception as e:
                 logger.error("Module raised an exception in is_3pid_allowed: %s", e)
                 raise SynapseError(code=500, msg="Internal Server Error")