summary refs log tree commit diff
path: root/synapse/util/async_helpers.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/async_helpers.py')
-rw-r--r--synapse/util/async_helpers.py131
1 files changed, 92 insertions, 39 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py

index 60c03a66fd..6a8e844d63 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py
@@ -18,9 +18,10 @@ import collections import inspect import itertools import logging -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager from typing import ( Any, + AsyncIterator, Awaitable, Callable, Collection, @@ -40,7 +41,7 @@ from typing import ( ) import attr -from typing_extensions import ContextManager +from typing_extensions import AsyncContextManager, Literal from twisted.internet import defer from twisted.internet.defer import CancelledError @@ -96,6 +97,10 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): __slots__ = ["_deferred", "_observers", "_result"] + _deferred: "defer.Deferred[_T]" + _observers: Union[List["defer.Deferred[_T]"], Tuple[()]] + _result: Union[None, Tuple[Literal[True], _T], Tuple[Literal[False], Failure]] + def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False): object.__setattr__(self, "_deferred", deferred) object.__setattr__(self, "_result", None) @@ -158,12 +163,14 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): effect the underlying deferred. """ if not self._result: + assert isinstance(self._observers, list) d: "defer.Deferred[_T]" = defer.Deferred() self._observers.append(d) return d + elif self._result[0]: + return defer.succeed(self._result[1]) else: - success, res = self._result - return defer.succeed(res) if success else defer.fail(res) + return defer.fail(self._result[1]) def observers(self) -> "Collection[defer.Deferred[_T]]": return self._observers @@ -175,6 +182,8 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): return self._result is not None and self._result[0] is True def get_result(self) -> Union[_T, Failure]: + if self._result is None: + raise ValueError(f"{self!r} has no result yet") return self._result[1] def __getattr__(self, name: str) -> Any: @@ -483,7 +492,7 @@ class ReadWriteLock: Example: - with await read_write_lock.read("test_key"): + async with read_write_lock.read("test_key"): # do some work """ @@ -506,22 +515,24 @@ class ReadWriteLock: # Latest writer queued self.key_to_current_writer: Dict[str, defer.Deferred] = {} - async def read(self, key: str) -> ContextManager: - new_defer: "defer.Deferred[None]" = defer.Deferred() - - curr_readers = self.key_to_current_readers.setdefault(key, set()) - curr_writer = self.key_to_current_writer.get(key, None) + def read(self, key: str) -> AsyncContextManager: + @asynccontextmanager + async def _ctx_manager() -> AsyncIterator[None]: + new_defer: "defer.Deferred[None]" = defer.Deferred() - curr_readers.add(new_defer) + curr_readers = self.key_to_current_readers.setdefault(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) - # We wait for the latest writer to finish writing. We can safely ignore - # any existing readers... as they're readers. - if curr_writer: - await make_deferred_yieldable(curr_writer) + curr_readers.add(new_defer) - @contextmanager - def _ctx_manager() -> Iterator[None]: try: + # We wait for the latest writer to finish writing. We can safely ignore + # any existing readers... as they're readers. + # May raise a `CancelledError` if the `Deferred` wrapping us is + # cancelled. The `Deferred` we are waiting on must not be cancelled, + # since we do not own it. + if curr_writer: + await make_deferred_yieldable(stop_cancellation(curr_writer)) yield finally: with PreserveLoggingContext(): @@ -530,29 +541,35 @@ class ReadWriteLock: return _ctx_manager() - async def write(self, key: str) -> ContextManager: - new_defer: "defer.Deferred[None]" = defer.Deferred() + def write(self, key: str) -> AsyncContextManager: + @asynccontextmanager + async def _ctx_manager() -> AsyncIterator[None]: + new_defer: "defer.Deferred[None]" = defer.Deferred() - curr_readers = self.key_to_current_readers.get(key, set()) - curr_writer = self.key_to_current_writer.get(key, None) + curr_readers = self.key_to_current_readers.get(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) - # We wait on all latest readers and writer. - to_wait_on = list(curr_readers) - if curr_writer: - to_wait_on.append(curr_writer) + # We wait on all latest readers and writer. + to_wait_on = list(curr_readers) + if curr_writer: + to_wait_on.append(curr_writer) - # We can clear the list of current readers since the new writer waits - # for them to finish. - curr_readers.clear() - self.key_to_current_writer[key] = new_defer + # We can clear the list of current readers since `new_defer` waits + # for them to finish. + curr_readers.clear() + self.key_to_current_writer[key] = new_defer - await make_deferred_yieldable(defer.gatherResults(to_wait_on)) - - @contextmanager - def _ctx_manager() -> Iterator[None]: + to_wait_on_defer = defer.gatherResults(to_wait_on) try: + # Wait for all current readers and the latest writer to finish. + # May raise a `CancelledError` immediately after the wait if the + # `Deferred` wrapping us is cancelled. We must only release the lock + # once we have acquired it, hence the use of `delay_cancellation` + # rather than `stop_cancellation`. + await make_deferred_yieldable(delay_cancellation(to_wait_on_defer)) yield finally: + # Release the lock. with PreserveLoggingContext(): new_defer.callback(None) # `self.key_to_current_writer[key]` may be missing if there was another @@ -678,12 +695,48 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": Synapse logcontext rules. Returns: - A new `Deferred`, which will contain the result of the original `Deferred`, - but will not propagate cancellation through to the original. When cancelled, - the new `Deferred` will fail with a `CancelledError` and will not follow the - Synapse logcontext rules. `make_deferred_yieldable` should be used to wrap - the new `Deferred`. + A new `Deferred`, which will contain the result of the original `Deferred`. + The new `Deferred` will not propagate cancellation through to the original. + When cancelled, the new `Deferred` will fail with a `CancelledError`. + + The new `Deferred` will not follow the Synapse logcontext rules and should be + wrapped with `make_deferred_yieldable`. """ - new_deferred: defer.Deferred[T] = defer.Deferred() + new_deferred: "defer.Deferred[T]" = defer.Deferred() + deferred.chainDeferred(new_deferred) + return new_deferred + + +def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": + """Delay cancellation of a `Deferred` until it resolves. + + Has the same effect as `stop_cancellation`, but the returned `Deferred` will not + resolve with a `CancelledError` until the original `Deferred` resolves. + + Args: + deferred: The `Deferred` to protect against cancellation. May optionally follow + the Synapse logcontext rules. + + Returns: + A new `Deferred`, which will contain the result of the original `Deferred`. + The new `Deferred` will not propagate cancellation through to the original. + When cancelled, the new `Deferred` will wait until the original `Deferred` + resolves before failing with a `CancelledError`. + + The new `Deferred` will follow the Synapse logcontext rules if `deferred` + follows the Synapse logcontext rules. Otherwise the new `Deferred` should be + wrapped with `make_deferred_yieldable`. + """ + + def handle_cancel(new_deferred: "defer.Deferred[T]") -> None: + # before the new deferred is cancelled, we `pause` it to stop the cancellation + # propagating. we then `unpause` it once the wrapped deferred completes, to + # propagate the exception. + new_deferred.pause() + new_deferred.errback(Failure(CancelledError())) + + deferred.addBoth(lambda _: new_deferred.unpause()) + + new_deferred: "defer.Deferred[T]" = defer.Deferred(handle_cancel) deferred.chainDeferred(new_deferred) return new_deferred