summary refs log tree commit diff
path: root/synapse/util/async_helpers.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/async_helpers.py')
-rw-r--r--synapse/util/async_helpers.py37
1 files changed, 18 insertions, 19 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 96efc5f3e3..20ce294209 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -27,20 +27,20 @@ from typing import (
     Generic,
     Hashable,
     Iterable,
+    Iterator,
     Optional,
     Set,
     TypeVar,
     Union,
+    cast,
 )
 
 import attr
 from typing_extensions import ContextManager
 
 from twisted.internet import defer
-from twisted.internet.base import ReactorBase
 from twisted.internet.defer import CancelledError
 from twisted.internet.interfaces import IReactorTime
-from twisted.python import failure
 from twisted.python.failure import Failure
 
 from synapse.logging.context import (
@@ -78,7 +78,7 @@ class ObservableDeferred(Generic[_T]):
         object.__setattr__(self, "_result", None)
         object.__setattr__(self, "_observers", [])
 
-        def callback(r):
+        def callback(r: _T) -> _T:
             object.__setattr__(self, "_result", (True, r))
 
             # once we have set _result, no more entries will be added to _observers,
@@ -98,7 +98,7 @@ class ObservableDeferred(Generic[_T]):
                     )
             return r
 
-        def errback(f):
+        def errback(f: Failure) -> Optional[Failure]:
             object.__setattr__(self, "_result", (False, f))
 
             # once we have set _result, no more entries will be added to _observers,
@@ -109,7 +109,7 @@ class ObservableDeferred(Generic[_T]):
             for observer in observers:
                 # This is a little bit of magic to correctly propagate stack
                 # traces when we `await` on one of the observer deferreds.
-                f.value.__failure__ = f
+                f.value.__failure__ = f  # type: ignore[union-attr]
                 try:
                     observer.errback(f)
                 except Exception as e:
@@ -271,8 +271,7 @@ class Linearizer:
         if not clock:
             from twisted.internet import reactor
 
-            assert isinstance(reactor, ReactorBase)
-            clock = Clock(reactor)
+            clock = Clock(cast(IReactorTime, reactor))
         self._clock = clock
         self.max_count = max_count
 
@@ -315,7 +314,7 @@ class Linearizer:
         # will release the lock.
 
         @contextmanager
-        def _ctx_manager(_):
+        def _ctx_manager(_: None) -> Iterator[None]:
             try:
                 yield
             finally:
@@ -356,7 +355,7 @@ class Linearizer:
         new_defer = make_deferred_yieldable(defer.Deferred())
         entry.deferreds[new_defer] = 1
 
-        def cb(_r):
+        def cb(_r: None) -> "defer.Deferred[None]":
             logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
             entry.count += 1
 
@@ -372,7 +371,7 @@ class Linearizer:
             # code must be synchronous, so this is the only sensible place.)
             return self._clock.sleep(0)
 
-        def eb(e):
+        def eb(e: Failure) -> Failure:
             logger.info("defer %r got err %r", new_defer, e)
             if isinstance(e, CancelledError):
                 logger.debug(
@@ -436,7 +435,7 @@ class ReadWriteLock:
             await make_deferred_yieldable(curr_writer)
 
         @contextmanager
-        def _ctx_manager():
+        def _ctx_manager() -> Iterator[None]:
             try:
                 yield
             finally:
@@ -465,7 +464,7 @@ class ReadWriteLock:
         await make_deferred_yieldable(defer.gatherResults(to_wait_on))
 
         @contextmanager
-        def _ctx_manager():
+        def _ctx_manager() -> Iterator[None]:
             try:
                 yield
             finally:
@@ -525,7 +524,7 @@ def timeout_deferred(
 
     delayed_call = reactor.callLater(timeout, time_it_out)
 
-    def convert_cancelled(value: failure.Failure):
+    def convert_cancelled(value: Failure) -> Failure:
         # if the original deferred was cancelled, and our timeout has fired, then
         # the reason it was cancelled was due to our timeout. Turn the CancelledError
         # into a TimeoutError.
@@ -535,7 +534,7 @@ def timeout_deferred(
 
     deferred.addErrback(convert_cancelled)
 
-    def cancel_timeout(result):
+    def cancel_timeout(result: _T) -> _T:
         # stop the pending call to cancel the deferred if it's been fired
         if delayed_call.active():
             delayed_call.cancel()
@@ -543,11 +542,11 @@ def timeout_deferred(
 
     deferred.addBoth(cancel_timeout)
 
-    def success_cb(val):
+    def success_cb(val: _T) -> None:
         if not new_d.called:
             new_d.callback(val)
 
-    def failure_cb(val):
+    def failure_cb(val: Failure) -> None:
         if not new_d.called:
             new_d.errback(val)
 
@@ -558,13 +557,13 @@ def timeout_deferred(
 
 # This class can't be generic because it uses slots with attrs.
 # See: https://github.com/python-attrs/attrs/issues/313
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class DoneAwaitable:  # should be: Generic[R]
     """Simple awaitable that returns the provided value."""
 
-    value = attr.ib(type=Any)  # should be: R
+    value: Any  # should be: R
 
-    def __await__(self):
+    def __await__(self) -> Any:
         return self
 
     def __iter__(self) -> "DoneAwaitable":