summary refs log tree commit diff
path: root/synapse/util/async_helpers.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/async_helpers.py')
-rw-r--r--synapse/util/async_helpers.py79
1 files changed, 77 insertions, 2 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 20ce294209..150a04b53e 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
 import collections
 import inspect
 import itertools
@@ -30,9 +31,11 @@ from typing import (
     Iterator,
     Optional,
     Set,
+    Tuple,
     TypeVar,
     Union,
     cast,
+    overload,
 )
 
 import attr
@@ -55,7 +58,26 @@ logger = logging.getLogger(__name__)
 _T = TypeVar("_T")
 
 
-class ObservableDeferred(Generic[_T]):
+class AbstractObservableDeferred(Generic[_T], metaclass=abc.ABCMeta):
+    """Abstract base class defining the consumer interface of ObservableDeferred"""
+
+    __slots__ = ()
+
+    @abc.abstractmethod
+    def observe(self) -> "defer.Deferred[_T]":
+        """Add a new observer for this ObservableDeferred
+
+        This returns a brand new deferred that is resolved when the underlying
+        deferred is resolved. Interacting with the returned deferred does not
+        effect the underlying deferred.
+
+        Note that the returned Deferred doesn't follow the Synapse logcontext rules -
+        you will probably want to `make_deferred_yieldable` it.
+        """
+        ...
+
+
+class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
     """Wraps a deferred object so that we can add observer deferreds. These
     observer deferreds do not affect the callback chain of the original
     deferred.
@@ -234,6 +256,59 @@ def yieldable_gather_results(
     ).addErrback(unwrapFirstError)
 
 
+T1 = TypeVar("T1")
+T2 = TypeVar("T2")
+T3 = TypeVar("T3")
+
+
+@overload
+def gather_results(
+    deferredList: Tuple[()], consumeErrors: bool = ...
+) -> "defer.Deferred[Tuple[()]]":
+    ...
+
+
+@overload
+def gather_results(
+    deferredList: Tuple["defer.Deferred[T1]"],
+    consumeErrors: bool = ...,
+) -> "defer.Deferred[Tuple[T1]]":
+    ...
+
+
+@overload
+def gather_results(
+    deferredList: Tuple["defer.Deferred[T1]", "defer.Deferred[T2]"],
+    consumeErrors: bool = ...,
+) -> "defer.Deferred[Tuple[T1, T2]]":
+    ...
+
+
+@overload
+def gather_results(
+    deferredList: Tuple[
+        "defer.Deferred[T1]", "defer.Deferred[T2]", "defer.Deferred[T3]"
+    ],
+    consumeErrors: bool = ...,
+) -> "defer.Deferred[Tuple[T1, T2, T3]]":
+    ...
+
+
+def gather_results(  # type: ignore[misc]
+    deferredList: Tuple["defer.Deferred[T1]", ...],
+    consumeErrors: bool = False,
+) -> "defer.Deferred[Tuple[T1, ...]]":
+    """Combines a tuple of `Deferred`s into a single `Deferred`.
+
+    Wraps `defer.gatherResults` to provide type annotations that support heterogenous
+    lists of `Deferred`s.
+    """
+    # The `type: ignore[misc]` above suppresses
+    # "Overloaded function implementation cannot produce return type of signature 1/2/3"
+    deferred = defer.gatherResults(deferredList, consumeErrors=consumeErrors)
+    return deferred.addCallback(tuple)
+
+
 @attr.s(slots=True)
 class _LinearizerEntry:
     # The number of things executing.
@@ -352,7 +427,7 @@ class Linearizer:
 
         logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
 
-        new_defer = make_deferred_yieldable(defer.Deferred())
+        new_defer: "defer.Deferred[None]" = make_deferred_yieldable(defer.Deferred())
         entry.deferreds[new_defer] = 1
 
         def cb(_r: None) -> "defer.Deferred[None]":