diff --git a/changelog.d/10491.misc b/changelog.d/10491.misc
new file mode 100644
index 0000000000..3867cf2682
--- /dev/null
+++ b/changelog.d/10491.misc
@@ -0,0 +1 @@
+Improve type annotations for `ObservableDeferred`.
diff --git a/synapse/notifier.py b/synapse/notifier.py
index c5fbebc17d..bbe337949a 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -111,8 +111,9 @@ class _NotifierUserStream:
self.last_notified_token = current_token
self.last_notified_ms = time_now_ms
- with PreserveLoggingContext():
- self.notify_deferred = ObservableDeferred(defer.Deferred())
+ self.notify_deferred: ObservableDeferred[StreamToken] = ObservableDeferred(
+ defer.Deferred()
+ )
def notify(
self,
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index a39877f0d5..0e8270746d 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -170,7 +170,9 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
end_item = queue[-1]
else:
# need to make a new queue item
- deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
+ deferred: ObservableDeferred[_PersistResult] = ObservableDeferred(
+ defer.Deferred(), consumeErrors=True
+ )
end_item = _EventPersistQueueItem(
events_and_contexts=[],
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 912cf85f89..a3b65aee27 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -23,6 +23,7 @@ from typing import (
Awaitable,
Callable,
Dict,
+ Generic,
Hashable,
Iterable,
List,
@@ -39,6 +40,7 @@ from twisted.internet import defer
from twisted.internet.defer import CancelledError
from twisted.internet.interfaces import IReactorTime
from twisted.python import failure
+from twisted.python.failure import Failure
from synapse.logging.context import (
PreserveLoggingContext,
@@ -52,7 +54,7 @@ logger = logging.getLogger(__name__)
_T = TypeVar("_T")
-class ObservableDeferred:
+class ObservableDeferred(Generic[_T]):
"""Wraps a deferred object so that we can add observer deferreds. These
observer deferreds do not affect the callback chain of the original
deferred.
@@ -70,7 +72,7 @@ class ObservableDeferred:
__slots__ = ["_deferred", "_observers", "_result"]
- def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False):
+ def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
object.__setattr__(self, "_deferred", deferred)
object.__setattr__(self, "_result", None)
object.__setattr__(self, "_observers", set())
@@ -115,7 +117,7 @@ class ObservableDeferred:
deferred.addCallbacks(callback, errback)
- def observe(self) -> defer.Deferred:
+ def observe(self) -> "defer.Deferred[_T]":
"""Observe the underlying deferred.
This returns a brand new deferred that is resolved when the underlying
@@ -123,7 +125,7 @@ class ObservableDeferred:
effect the underlying deferred.
"""
if not self._result:
- d: "defer.Deferred[Any]" = defer.Deferred()
+ d: "defer.Deferred[_T]" = defer.Deferred()
def remove(r):
self._observers.discard(d)
@@ -137,7 +139,7 @@ class ObservableDeferred:
success, res = self._result
return defer.succeed(res) if success else defer.fail(res)
- def observers(self) -> List[defer.Deferred]:
+ def observers(self) -> "List[defer.Deferred[_T]]":
return self._observers
def has_called(self) -> bool:
@@ -146,7 +148,7 @@ class ObservableDeferred:
def has_succeeded(self) -> bool:
return self._result is not None and self._result[0] is True
- def get_result(self) -> Any:
+ def get_result(self) -> Union[_T, Failure]:
return self._result[1]
def __getattr__(self, name: str) -> Any:
|