summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/10491.misc1
-rw-r--r--synapse/notifier.py5
-rw-r--r--synapse/storage/persist_events.py4
-rw-r--r--synapse/util/async_helpers.py14
4 files changed, 15 insertions, 9 deletions
diff --git a/changelog.d/10491.misc b/changelog.d/10491.misc
new file mode 100644
index 0000000000..3867cf2682
--- /dev/null
+++ b/changelog.d/10491.misc
@@ -0,0 +1 @@
+Improve type annotations for `ObservableDeferred`.
diff --git a/synapse/notifier.py b/synapse/notifier.py
index c5fbebc17d..bbe337949a 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -111,8 +111,9 @@ class _NotifierUserStream:
         self.last_notified_token = current_token
         self.last_notified_ms = time_now_ms
 
-        with PreserveLoggingContext():
-            self.notify_deferred = ObservableDeferred(defer.Deferred())
+        self.notify_deferred: ObservableDeferred[StreamToken] = ObservableDeferred(
+            defer.Deferred()
+        )
 
     def notify(
         self,
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index a39877f0d5..0e8270746d 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -170,7 +170,9 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
             end_item = queue[-1]
         else:
             # need to make a new queue item
-            deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
+            deferred: ObservableDeferred[_PersistResult] = ObservableDeferred(
+                defer.Deferred(), consumeErrors=True
+            )
 
             end_item = _EventPersistQueueItem(
                 events_and_contexts=[],
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 912cf85f89..a3b65aee27 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -23,6 +23,7 @@ from typing import (
     Awaitable,
     Callable,
     Dict,
+    Generic,
     Hashable,
     Iterable,
     List,
@@ -39,6 +40,7 @@ from twisted.internet import defer
 from twisted.internet.defer import CancelledError
 from twisted.internet.interfaces import IReactorTime
 from twisted.python import failure
+from twisted.python.failure import Failure
 
 from synapse.logging.context import (
     PreserveLoggingContext,
@@ -52,7 +54,7 @@ logger = logging.getLogger(__name__)
 _T = TypeVar("_T")
 
 
-class ObservableDeferred:
+class ObservableDeferred(Generic[_T]):
     """Wraps a deferred object so that we can add observer deferreds. These
     observer deferreds do not affect the callback chain of the original
     deferred.
@@ -70,7 +72,7 @@ class ObservableDeferred:
 
     __slots__ = ["_deferred", "_observers", "_result"]
 
-    def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False):
+    def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
         object.__setattr__(self, "_deferred", deferred)
         object.__setattr__(self, "_result", None)
         object.__setattr__(self, "_observers", set())
@@ -115,7 +117,7 @@ class ObservableDeferred:
 
         deferred.addCallbacks(callback, errback)
 
-    def observe(self) -> defer.Deferred:
+    def observe(self) -> "defer.Deferred[_T]":
         """Observe the underlying deferred.
 
         This returns a brand new deferred that is resolved when the underlying
@@ -123,7 +125,7 @@ class ObservableDeferred:
         effect the underlying deferred.
         """
         if not self._result:
-            d: "defer.Deferred[Any]" = defer.Deferred()
+            d: "defer.Deferred[_T]" = defer.Deferred()
 
             def remove(r):
                 self._observers.discard(d)
@@ -137,7 +139,7 @@ class ObservableDeferred:
             success, res = self._result
             return defer.succeed(res) if success else defer.fail(res)
 
-    def observers(self) -> List[defer.Deferred]:
+    def observers(self) -> "List[defer.Deferred[_T]]":
         return self._observers
 
     def has_called(self) -> bool:
@@ -146,7 +148,7 @@ class ObservableDeferred:
     def has_succeeded(self) -> bool:
         return self._result is not None and self._result[0] is True
 
-    def get_result(self) -> Any:
+    def get_result(self) -> Union[_T, Failure]:
         return self._result[1]
 
     def __getattr__(self, name: str) -> Any: