diff options
Diffstat (limited to 'synapse/util')
-rw-r--r-- | synapse/util/async_helpers.py | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 912cf85f89..a3b65aee27 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -23,6 +23,7 @@ from typing import ( Awaitable, Callable, Dict, + Generic, Hashable, Iterable, List, @@ -39,6 +40,7 @@ from twisted.internet import defer from twisted.internet.defer import CancelledError from twisted.internet.interfaces import IReactorTime from twisted.python import failure +from twisted.python.failure import Failure from synapse.logging.context import ( PreserveLoggingContext, @@ -52,7 +54,7 @@ logger = logging.getLogger(__name__) _T = TypeVar("_T") -class ObservableDeferred: +class ObservableDeferred(Generic[_T]): """Wraps a deferred object so that we can add observer deferreds. These observer deferreds do not affect the callback chain of the original deferred. @@ -70,7 +72,7 @@ class ObservableDeferred: __slots__ = ["_deferred", "_observers", "_result"] - def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False): + def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False): object.__setattr__(self, "_deferred", deferred) object.__setattr__(self, "_result", None) object.__setattr__(self, "_observers", set()) @@ -115,7 +117,7 @@ class ObservableDeferred: deferred.addCallbacks(callback, errback) - def observe(self) -> defer.Deferred: + def observe(self) -> "defer.Deferred[_T]": """Observe the underlying deferred. This returns a brand new deferred that is resolved when the underlying @@ -123,7 +125,7 @@ class ObservableDeferred: effect the underlying deferred. """ if not self._result: - d: "defer.Deferred[Any]" = defer.Deferred() + d: "defer.Deferred[_T]" = defer.Deferred() def remove(r): self._observers.discard(d) @@ -137,7 +139,7 @@ class ObservableDeferred: success, res = self._result return defer.succeed(res) if success else defer.fail(res) - def observers(self) -> List[defer.Deferred]: + def observers(self) -> "List[defer.Deferred[_T]]": return self._observers def has_called(self) -> bool: @@ -146,7 +148,7 @@ class ObservableDeferred: def has_succeeded(self) -> bool: return self._result is not None and self._result[0] is True - def get_result(self) -> Any: + def get_result(self) -> Union[_T, Failure]: return self._result[1] def __getattr__(self, name: str) -> Any: |