diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 014db1355b..a3b65aee27 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -23,6 +23,7 @@ from typing import (
Awaitable,
Callable,
Dict,
+ Generic,
Hashable,
Iterable,
List,
@@ -39,6 +40,7 @@ from twisted.internet import defer
from twisted.internet.defer import CancelledError
from twisted.internet.interfaces import IReactorTime
from twisted.python import failure
+from twisted.python.failure import Failure
from synapse.logging.context import (
PreserveLoggingContext,
@@ -49,8 +51,10 @@ from synapse.util import Clock, unwrapFirstError
logger = logging.getLogger(__name__)
+_T = TypeVar("_T")
-class ObservableDeferred:
+
+class ObservableDeferred(Generic[_T]):
"""Wraps a deferred object so that we can add observer deferreds. These
observer deferreds do not affect the callback chain of the original
deferred.
@@ -68,7 +72,7 @@ class ObservableDeferred:
__slots__ = ["_deferred", "_observers", "_result"]
- def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False):
+ def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
object.__setattr__(self, "_deferred", deferred)
object.__setattr__(self, "_result", None)
object.__setattr__(self, "_observers", set())
@@ -113,7 +117,7 @@ class ObservableDeferred:
deferred.addCallbacks(callback, errback)
- def observe(self) -> defer.Deferred:
+ def observe(self) -> "defer.Deferred[_T]":
"""Observe the underlying deferred.
This returns a brand new deferred that is resolved when the underlying
@@ -121,7 +125,7 @@ class ObservableDeferred:
effect the underlying deferred.
"""
if not self._result:
- d = defer.Deferred()
+ d: "defer.Deferred[_T]" = defer.Deferred()
def remove(r):
self._observers.discard(d)
@@ -135,7 +139,7 @@ class ObservableDeferred:
success, res = self._result
return defer.succeed(res) if success else defer.fail(res)
- def observers(self) -> List[defer.Deferred]:
+ def observers(self) -> "List[defer.Deferred[_T]]":
return self._observers
def has_called(self) -> bool:
@@ -144,7 +148,7 @@ class ObservableDeferred:
def has_succeeded(self) -> bool:
return self._result is not None and self._result[0] is True
- def get_result(self) -> Any:
+ def get_result(self) -> Union[_T, Failure]:
return self._result[1]
def __getattr__(self, name: str) -> Any:
@@ -415,7 +419,7 @@ class ReadWriteLock:
self.key_to_current_writer: Dict[str, defer.Deferred] = {}
async def read(self, key: str) -> ContextManager:
- new_defer = defer.Deferred()
+ new_defer: "defer.Deferred[None]" = defer.Deferred()
curr_readers = self.key_to_current_readers.setdefault(key, set())
curr_writer = self.key_to_current_writer.get(key, None)
@@ -438,7 +442,7 @@ class ReadWriteLock:
return _ctx_manager()
async def write(self, key: str) -> ContextManager:
- new_defer = defer.Deferred()
+ new_defer: "defer.Deferred[None]" = defer.Deferred()
curr_readers = self.key_to_current_readers.get(key, set())
curr_writer = self.key_to_current_writer.get(key, None)
@@ -471,10 +475,8 @@ R = TypeVar("R")
def timeout_deferred(
- deferred: defer.Deferred,
- timeout: float,
- reactor: IReactorTime,
-) -> defer.Deferred:
+ deferred: "defer.Deferred[_T]", timeout: float, reactor: IReactorTime
+) -> "defer.Deferred[_T]":
"""The in built twisted `Deferred.addTimeout` fails to time out deferreds
that have a canceller that throws exceptions. This method creates a new
deferred that wraps and times out the given deferred, correctly handling
@@ -497,7 +499,7 @@ def timeout_deferred(
Returns:
A new Deferred, which will errback with defer.TimeoutError on timeout.
"""
- new_d = defer.Deferred()
+ new_d: "defer.Deferred[_T]" = defer.Deferred()
timed_out = [False]
diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py
index 891bee0b33..e58dd91eda 100644
--- a/synapse/util/caches/cached_call.py
+++ b/synapse/util/caches/cached_call.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import enum
from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union
from twisted.internet.defer import Deferred
@@ -22,6 +22,10 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
TV = TypeVar("TV")
+class _Sentinel(enum.Enum):
+ sentinel = object()
+
+
class CachedCall(Generic[TV]):
"""A wrapper for asynchronous calls whose results should be shared
@@ -65,7 +69,7 @@ class CachedCall(Generic[TV]):
"""
self._callable: Optional[Callable[[], Awaitable[TV]]] = f
self._deferred: Optional[Deferred] = None
- self._result: Union[None, Failure, TV] = None
+ self._result: Union[_Sentinel, TV, Failure] = _Sentinel.sentinel
async def get(self) -> TV:
"""Kick off the call if necessary, and return the result"""
@@ -78,8 +82,9 @@ class CachedCall(Generic[TV]):
self._callable = None
# once the deferred completes, store the result. We cannot simply leave the
- # result in the deferred, since if it's a Failure, GCing the deferred
- # would then log a critical error about unhandled Failures.
+ # result in the deferred, since `awaiting` a deferred destroys its result.
+ # (Also, if it's a Failure, GCing the deferred would log a critical error
+ # about unhandled Failures)
def got_result(r):
self._result = r
@@ -92,13 +97,15 @@ class CachedCall(Generic[TV]):
# and any eventual exception may not be reported.
# we can now await the deferred, and once it completes, return the result.
- await make_deferred_yieldable(self._deferred)
+ if isinstance(self._result, _Sentinel):
+ await make_deferred_yieldable(self._deferred)
+ assert not isinstance(self._result, _Sentinel)
+
+ if isinstance(self._result, Failure):
+ self._result.raiseException()
+ raise AssertionError("unexpected return from Failure.raiseException")
- # I *think* this is the easiest way to correctly raise a Failure without having
- # to gut-wrench into the implementation of Deferred.
- d = Deferred()
- d.callback(self._result)
- return await d
+ return self._result
class RetryOnExceptionCachedCall(Generic[TV]):
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 8c6fafc677..b6456392cd 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -16,7 +16,16 @@
import enum
import threading
-from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, Union
+from typing import (
+ Callable,
+ Generic,
+ Iterable,
+ MutableMapping,
+ Optional,
+ TypeVar,
+ Union,
+ cast,
+)
from prometheus_client import Gauge
@@ -166,7 +175,7 @@ class DeferredCache(Generic[KT, VT]):
def set(
self,
key: KT,
- value: defer.Deferred,
+ value: "defer.Deferred[VT]",
callback: Optional[Callable[[], None]] = None,
) -> defer.Deferred:
"""Adds a new entry to the cache (or updates an existing one).
@@ -214,7 +223,7 @@ class DeferredCache(Generic[KT, VT]):
if value.called:
result = value.result
if not isinstance(result, failure.Failure):
- self.cache.set(key, result, callbacks)
+ self.cache.set(key, cast(VT, result), callbacks)
return value
# otherwise, we'll add an entry to the _pending_deferred_cache for now,
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 1e8e6b1d01..1ca31e41ac 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -413,7 +413,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
# relevant result for that key.
deferreds_map = {}
for arg in missing:
- deferred = defer.Deferred()
+ deferred: "defer.Deferred[Any]" = defer.Deferred()
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
cache.set(key, deferred, callback=invalidate_callback)
|