From 858363d0b7e58fd71875b25d183537bb3b5a397f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 28 Jul 2021 20:55:50 +0100 Subject: Generics for `ObservableDeferred` (#10491) Now that `Deferred` is a generic class, let's update `ObeservableDeferred` to follow suit. --- changelog.d/10491.misc | 1 + synapse/notifier.py | 5 +++-- synapse/storage/persist_events.py | 4 +++- synapse/util/async_helpers.py | 14 ++++++++------ 4 files changed, 15 insertions(+), 9 deletions(-) create mode 100644 changelog.d/10491.misc diff --git a/changelog.d/10491.misc b/changelog.d/10491.misc new file mode 100644 index 0000000000..3867cf2682 --- /dev/null +++ b/changelog.d/10491.misc @@ -0,0 +1 @@ +Improve type annotations for `ObservableDeferred`. diff --git a/synapse/notifier.py b/synapse/notifier.py index c5fbebc17d..bbe337949a 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -111,8 +111,9 @@ class _NotifierUserStream: self.last_notified_token = current_token self.last_notified_ms = time_now_ms - with PreserveLoggingContext(): - self.notify_deferred = ObservableDeferred(defer.Deferred()) + self.notify_deferred: ObservableDeferred[StreamToken] = ObservableDeferred( + defer.Deferred() + ) def notify( self, diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index a39877f0d5..0e8270746d 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -170,7 +170,9 @@ class _EventPeristenceQueue(Generic[_PersistResult]): end_item = queue[-1] else: # need to make a new queue item - deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) + deferred: ObservableDeferred[_PersistResult] = ObservableDeferred( + defer.Deferred(), consumeErrors=True + ) end_item = _EventPersistQueueItem( events_and_contexts=[], diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 912cf85f89..a3b65aee27 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -23,6 +23,7 @@ from typing import ( Awaitable, Callable, Dict, + Generic, Hashable, Iterable, List, @@ -39,6 +40,7 @@ from twisted.internet import defer from twisted.internet.defer import CancelledError from twisted.internet.interfaces import IReactorTime from twisted.python import failure +from twisted.python.failure import Failure from synapse.logging.context import ( PreserveLoggingContext, @@ -52,7 +54,7 @@ logger = logging.getLogger(__name__) _T = TypeVar("_T") -class ObservableDeferred: +class ObservableDeferred(Generic[_T]): """Wraps a deferred object so that we can add observer deferreds. These observer deferreds do not affect the callback chain of the original deferred. @@ -70,7 +72,7 @@ class ObservableDeferred: __slots__ = ["_deferred", "_observers", "_result"] - def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False): + def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False): object.__setattr__(self, "_deferred", deferred) object.__setattr__(self, "_result", None) object.__setattr__(self, "_observers", set()) @@ -115,7 +117,7 @@ class ObservableDeferred: deferred.addCallbacks(callback, errback) - def observe(self) -> defer.Deferred: + def observe(self) -> "defer.Deferred[_T]": """Observe the underlying deferred. This returns a brand new deferred that is resolved when the underlying @@ -123,7 +125,7 @@ class ObservableDeferred: effect the underlying deferred. """ if not self._result: - d: "defer.Deferred[Any]" = defer.Deferred() + d: "defer.Deferred[_T]" = defer.Deferred() def remove(r): self._observers.discard(d) @@ -137,7 +139,7 @@ class ObservableDeferred: success, res = self._result return defer.succeed(res) if success else defer.fail(res) - def observers(self) -> List[defer.Deferred]: + def observers(self) -> "List[defer.Deferred[_T]]": return self._observers def has_called(self) -> bool: @@ -146,7 +148,7 @@ class ObservableDeferred: def has_succeeded(self) -> bool: return self._result is not None and self._result[0] is True - def get_result(self) -> Any: + def get_result(self) -> Union[_T, Failure]: return self._result[1] def __getattr__(self, name: str) -> Any: -- cgit 1.4.1