summary refs log tree commit diff
path: root/synapse/util/async_helpers.py
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2021-07-28 20:55:50 +0100
committerGitHub <noreply@github.com>2021-07-28 19:55:50 +0000
commit858363d0b7e58fd71875b25d183537bb3b5a397f (patch)
tree080bc8232ee759309c29ee50c875e4bcddd39b2f /synapse/util/async_helpers.py
parentMake historical events discoverable from backfill for servers without any scr... (diff)
downloadsynapse-858363d0b7e58fd71875b25d183537bb3b5a397f.tar.xz
Generics for `ObservableDeferred` (#10491)
Now that `Deferred` is a generic class, let's update `ObeservableDeferred` to
follow suit.
Diffstat (limited to '')
-rw-r--r--synapse/util/async_helpers.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 912cf85f89..a3b65aee27 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -23,6 +23,7 @@ from typing import (
     Awaitable,
     Callable,
     Dict,
+    Generic,
     Hashable,
     Iterable,
     List,
@@ -39,6 +40,7 @@ from twisted.internet import defer
 from twisted.internet.defer import CancelledError
 from twisted.internet.interfaces import IReactorTime
 from twisted.python import failure
+from twisted.python.failure import Failure
 
 from synapse.logging.context import (
     PreserveLoggingContext,
@@ -52,7 +54,7 @@ logger = logging.getLogger(__name__)
 _T = TypeVar("_T")
 
 
-class ObservableDeferred:
+class ObservableDeferred(Generic[_T]):
     """Wraps a deferred object so that we can add observer deferreds. These
     observer deferreds do not affect the callback chain of the original
     deferred.
@@ -70,7 +72,7 @@ class ObservableDeferred:
 
     __slots__ = ["_deferred", "_observers", "_result"]
 
-    def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False):
+    def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
         object.__setattr__(self, "_deferred", deferred)
         object.__setattr__(self, "_result", None)
         object.__setattr__(self, "_observers", set())
@@ -115,7 +117,7 @@ class ObservableDeferred:
 
         deferred.addCallbacks(callback, errback)
 
-    def observe(self) -> defer.Deferred:
+    def observe(self) -> "defer.Deferred[_T]":
         """Observe the underlying deferred.
 
         This returns a brand new deferred that is resolved when the underlying
@@ -123,7 +125,7 @@ class ObservableDeferred:
         effect the underlying deferred.
         """
         if not self._result:
-            d: "defer.Deferred[Any]" = defer.Deferred()
+            d: "defer.Deferred[_T]" = defer.Deferred()
 
             def remove(r):
                 self._observers.discard(d)
@@ -137,7 +139,7 @@ class ObservableDeferred:
             success, res = self._result
             return defer.succeed(res) if success else defer.fail(res)
 
-    def observers(self) -> List[defer.Deferred]:
+    def observers(self) -> "List[defer.Deferred[_T]]":
         return self._observers
 
     def has_called(self) -> bool:
@@ -146,7 +148,7 @@ class ObservableDeferred:
     def has_succeeded(self) -> bool:
         return self._result is not None and self._result[0] is True
 
-    def get_result(self) -> Any:
+    def get_result(self) -> Union[_T, Failure]:
         return self._result[1]
 
     def __getattr__(self, name: str) -> Any: