summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/async_helpers.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 912cf85f89..a3b65aee27 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -23,6 +23,7 @@ from typing import (
     Awaitable,
     Callable,
     Dict,
+    Generic,
     Hashable,
     Iterable,
     List,
@@ -39,6 +40,7 @@ from twisted.internet import defer
 from twisted.internet.defer import CancelledError
 from twisted.internet.interfaces import IReactorTime
 from twisted.python import failure
+from twisted.python.failure import Failure
 
 from synapse.logging.context import (
     PreserveLoggingContext,
@@ -52,7 +54,7 @@ logger = logging.getLogger(__name__)
 _T = TypeVar("_T")
 
 
-class ObservableDeferred:
+class ObservableDeferred(Generic[_T]):
     """Wraps a deferred object so that we can add observer deferreds. These
     observer deferreds do not affect the callback chain of the original
     deferred.
@@ -70,7 +72,7 @@ class ObservableDeferred:
 
     __slots__ = ["_deferred", "_observers", "_result"]
 
-    def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False):
+    def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
         object.__setattr__(self, "_deferred", deferred)
         object.__setattr__(self, "_result", None)
         object.__setattr__(self, "_observers", set())
@@ -115,7 +117,7 @@ class ObservableDeferred:
 
         deferred.addCallbacks(callback, errback)
 
-    def observe(self) -> defer.Deferred:
+    def observe(self) -> "defer.Deferred[_T]":
         """Observe the underlying deferred.
 
         This returns a brand new deferred that is resolved when the underlying
@@ -123,7 +125,7 @@ class ObservableDeferred:
         effect the underlying deferred.
         """
         if not self._result:
-            d: "defer.Deferred[Any]" = defer.Deferred()
+            d: "defer.Deferred[_T]" = defer.Deferred()
 
             def remove(r):
                 self._observers.discard(d)
@@ -137,7 +139,7 @@ class ObservableDeferred:
             success, res = self._result
             return defer.succeed(res) if success else defer.fail(res)
 
-    def observers(self) -> List[defer.Deferred]:
+    def observers(self) -> "List[defer.Deferred[_T]]":
         return self._observers
 
     def has_called(self) -> bool:
@@ -146,7 +148,7 @@ class ObservableDeferred:
     def has_succeeded(self) -> bool:
         return self._result is not None and self._result[0] is True
 
-    def get_result(self) -> Any:
+    def get_result(self) -> Union[_T, Failure]:
         return self._result[1]
 
     def __getattr__(self, name: str) -> Any: