summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
authorSean Quah <8349537+squahtx@users.noreply.github.com>2022-03-04 15:37:02 +0000
committerGitHub <noreply@github.com>2022-03-04 15:37:02 +0000
commit75574726a766f09d955c05672d400c65cb341810 (patch)
treed0ba13875c21e06dfa4c87ee7c9225f363fc4ea4 /synapse/util
parentAdd test for `ObservableDeferred`'s cancellation behaviour (#12149) (diff)
downloadsynapse-75574726a766f09d955c05672d400c65cb341810.tar.xz
Add type hints for `ObservableDeferred` attributes (#12159)
Signed-off-by: Sean Quah <seanq@element.io>
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/async_helpers.py14
1 files changed, 11 insertions, 3 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 60c03a66fd..a9f67dcbac 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -40,7 +40,7 @@ from typing import (
 )
 
 import attr
-from typing_extensions import ContextManager
+from typing_extensions import ContextManager, Literal
 
 from twisted.internet import defer
 from twisted.internet.defer import CancelledError
@@ -96,6 +96,10 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
 
     __slots__ = ["_deferred", "_observers", "_result"]
 
+    _deferred: "defer.Deferred[_T]"
+    _observers: Union[List["defer.Deferred[_T]"], Tuple[()]]
+    _result: Union[None, Tuple[Literal[True], _T], Tuple[Literal[False], Failure]]
+
     def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
         object.__setattr__(self, "_deferred", deferred)
         object.__setattr__(self, "_result", None)
@@ -158,12 +162,14 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
         effect the underlying deferred.
         """
         if not self._result:
+            assert isinstance(self._observers, list)
             d: "defer.Deferred[_T]" = defer.Deferred()
             self._observers.append(d)
             return d
+        elif self._result[0]:
+            return defer.succeed(self._result[1])
         else:
-            success, res = self._result
-            return defer.succeed(res) if success else defer.fail(res)
+            return defer.fail(self._result[1])
 
     def observers(self) -> "Collection[defer.Deferred[_T]]":
         return self._observers
@@ -175,6 +181,8 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
         return self._result is not None and self._result[0] is True
 
     def get_result(self) -> Union[_T, Failure]:
+        if self._result is None:
+            raise ValueError(f"{self!r} has no result yet")
         return self._result[1]
 
     def __getattr__(self, name: str) -> Any: