author     Patrick Cloke <clokep@users.noreply.github.com>   2020-09-08 16:50:51 -0400
committer  GitHub <noreply@github.com>                       2020-09-08 16:50:51 -0400
commit     e45b834119468272816c6558ebadb5cc286f148b (patch)
tree       9c5d7fc58393e61a89e37e8ff9f123d2654d069c /synapse/util/async_helpers.py
parent     Fix mypy error on develop (#8282) (diff)
download   synapse-e45b834119468272816c6558ebadb5cc286f148b.tar.xz
Add types to async_helpers (#8260)
Diffstat (limited to 'synapse/util/async_helpers.py')
 -rw-r--r--  synapse/util/async_helpers.py  135
 1 file changed, 85 insertions(+), 50 deletions(-)
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index bb57e27beb..67ce9a5f39 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -17,13 +17,25 @@
 import collections
 import logging
 from contextlib import contextmanager
-from typing import Dict, Sequence, Set, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Hashable,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    TypeVar,
+    Union,
+)
 
 import attr
 from typing_extensions import ContextManager
 
 from twisted.internet import defer
 from twisted.internet.defer import CancelledError
+from twisted.internet.interfaces import IReactorTime
 from twisted.python import failure
 
 from synapse.logging.context import (
@@ -54,7 +66,7 @@ class ObservableDeferred:
 
     __slots__ = ["_deferred", "_observers", "_result"]
 
-    def __init__(self, deferred, consumeErrors=False):
+    def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False):
         object.__setattr__(self, "_deferred", deferred)
         object.__setattr__(self, "_result", None)
         object.__setattr__(self, "_observers", set())
@@ -111,25 +123,25 @@ class ObservableDeferred:
             success, res = self._result
             return defer.succeed(res) if success else defer.fail(res)
 
-    def observers(self):
+    def observers(self) -> List[defer.Deferred]:
         return self._observers
 
-    def has_called(self):
+    def has_called(self) -> bool:
         return self._result is not None
 
-    def has_succeeded(self):
+    def has_succeeded(self) -> bool:
         return self._result is not None and self._result[0] is True
 
-    def get_result(self):
+    def get_result(self) -> Any:
         return self._result[1]
 
-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
         return getattr(self._deferred, name)
 
-    def __setattr__(self, name, value):
+    def __setattr__(self, name: str, value: Any) -> None:
         setattr(self._deferred, name, value)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
             id(self),
             self._result,
@@ -137,18 +149,20 @@ class ObservableDeferred:
         )
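
For reference, a minimal usage sketch of ObservableDeferred (not part of the patch above; the variable names are illustrative):

    from twisted.internet import defer
    from synapse.util.async_helpers import ObservableDeferred

    underlying = defer.Deferred()
    observable = ObservableDeferred(underlying, consumeErrors=True)

    # Each caller gets its own Deferred that fires with the shared result.
    first = observable.observe()
    second = observable.observe()

    underlying.callback("result")
    # Both `first` and `second` have now fired with "result".
    assert observable.has_called() and observable.has_succeeded()
    assert observable.get_result() == "result"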
 
 
-def concurrently_execute(func, args, limit):
-    """Executes the function with each argument conncurrently while limiting
+def concurrently_execute(
+    func: Callable, args: Iterable[Any], limit: int
+) -> defer.Deferred:
+    """Executes the function with each argument concurrently while limiting
     the number of concurrent executions.
 
     Args:
-        func (func): Function to execute, should return a deferred or coroutine.
-        args (Iterable): List of arguments to pass to func, each invocation of func
+        func: Function to execute, should return a deferred or coroutine.
+        args: List of arguments to pass to func, each invocation of func
             gets a single argument.
-        limit (int): Maximum number of conccurent executions.
+        limit: Maximum number of concurrent executions.
 
     Returns:
-        deferred: Resolved when all function invocations have finished.
+        Deferred[list]: Resolved when all function invocations have finished.
     """
     it = iter(args)
 
@@ -167,14 +181,17 @@ def concurrently_execute(func, args, limit):
     ).addErrback(unwrapFirstError)
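
For reference, a hedged usage sketch of concurrently_execute (not part of the patch; process_room is a hypothetical coroutine):

    from synapse.util.async_helpers import concurrently_execute

    async def process_room(room_id: str) -> None:
        ...  # hypothetical per-item work

    async def process_rooms(room_ids) -> None:
        # At most three invocations of process_room are in flight at once;
        # the returned Deferred resolves once every invocation has finished.
        await concurrently_execute(process_room, room_ids, limit=3)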
 
 
-def yieldable_gather_results(func, iter, *args, **kwargs):
+def yieldable_gather_results(
+    func: Callable, iter: Iterable, *args: Any, **kwargs: Any
+) -> defer.Deferred:
     """Executes the function with each argument concurrently.
 
     Args:
-        func (func): Function to execute that returns a Deferred
-        iter (iter): An iterable that yields items that get passed as the first
+        func: Function to execute that returns a Deferred
+        iter: An iterable that yields items that get passed as the first
             argument to the function
         *args: Arguments to be passed to each call to func
+        **kwargs: Keyword arguments to be passed to each call to func
 
     Returns:
         Deferred[list]: Resolved when all functions have been invoked, or errors if
@@ -188,24 +205,37 @@ def yieldable_gather_results(func, iter, *args, **kwargs):
     ).addErrback(unwrapFirstError)
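
For reference, a hedged usage sketch of yieldable_gather_results (not part of the patch; fetch_display_name is a hypothetical coroutine):

    from synapse.util.async_helpers import yieldable_gather_results

    async def fetch_display_name(user_id: str, prefix: str = "") -> str:
        return prefix + user_id  # hypothetical lookup

    async def fetch_all(user_ids) -> list:
        # Extra positional/keyword arguments are forwarded to every call;
        # results come back as a list in the same order as the input iterable.
        return await yieldable_gather_results(fetch_display_name, user_ids, prefix="@")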
 
 
+@attr.s(slots=True)
+class _LinearizerEntry:
+    # The number of things executing.
+    count = attr.ib(type=int)
+    # Deferreds for the things blocked from executing.
+    deferreds = attr.ib(type=collections.OrderedDict)
+
+
 class Linearizer:
     """Limits concurrent access to resources based on a key. Useful to ensure
     only a few things happen at a time on a given resource.
 
     Example:
 
-        with (yield limiter.queue("test_key")):
+        with await limiter.queue("test_key"):
             # do some work.
 
     """
 
-    def __init__(self, name=None, max_count=1, clock=None):
+    def __init__(
+        self,
+        name: Optional[str] = None,
+        max_count: int = 1,
+        clock: Optional[Clock] = None,
+    ):
         """
         Args:
-            max_count(int): The maximum number of concurrent accesses
+            max_count: The maximum number of concurrent accesses
         """
         if name is None:
-            self.name = id(self)
+            self.name = id(self)  # type: Union[str, int]
         else:
             self.name = name
 
@@ -216,15 +246,10 @@ class Linearizer:
         self._clock = clock
         self.max_count = max_count
 
-        # key_to_defer is a map from the key to a 2 element list where
-        # the first element is the number of things executing, and
-        # the second element is an OrderedDict, where the keys are deferreds for the
-        # things blocked from executing.
-        self.key_to_defer = (
-            {}
-        )  # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
+        # key_to_defer is a map from the key to a _LinearizerEntry.
+        self.key_to_defer = {}  # type: Dict[Hashable, _LinearizerEntry]
 
-    def is_queued(self, key) -> bool:
+    def is_queued(self, key: Hashable) -> bool:
         """Checks whether there is a process queued up waiting
         """
         entry = self.key_to_defer.get(key)
@@ -234,25 +259,27 @@ class Linearizer:
 
         # There are waiting deferreds only if the OrderedDict of deferreds is
         # non-empty.
-        return bool(entry[1])
+        return bool(entry.deferreds)
 
-    def queue(self, key):
+    def queue(self, key: Hashable) -> defer.Deferred:
         # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
         # (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not
         # propagated inside inlineCallbacks until Twisted 18.7)
-        entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
+        entry = self.key_to_defer.setdefault(
+            key, _LinearizerEntry(0, collections.OrderedDict())
+        )
 
         # If the number of things executing is greater than the maximum
         # then add a deferred to the list of blocked items
         # When one of the things currently executing finishes it will callback
         # this item so that it can continue executing.
-        if entry[0] >= self.max_count:
+        if entry.count >= self.max_count:
             res = self._await_lock(key)
         else:
             logger.debug(
                 "Acquired uncontended linearizer lock %r for key %r", self.name, key
             )
-            entry[0] += 1
+            entry.count += 1
             res = defer.succeed(None)
 
         # once we successfully get the lock, we need to return a context manager which
@@ -267,15 +294,15 @@ class Linearizer:
 
                 # We've finished executing so check if there are any things
                 # blocked waiting to execute and start one of them
-                entry[0] -= 1
+                entry.count -= 1
 
-                if entry[1]:
-                    (next_def, _) = entry[1].popitem(last=False)
+                if entry.deferreds:
+                    (next_def, _) = entry.deferreds.popitem(last=False)
 
                     # we need to run the next thing in the sentinel context.
                     with PreserveLoggingContext():
                         next_def.callback(None)
-                elif entry[0] == 0:
+                elif entry.count == 0:
                     # We were the last thing for this key: remove it from the
                     # map.
                     del self.key_to_defer[key]
@@ -283,7 +310,7 @@ class Linearizer:
         res.addCallback(_ctx_manager)
         return res
 
-    def _await_lock(self, key):
+    def _await_lock(self, key: Hashable) -> defer.Deferred:
         """Helper for queue: adds a deferred to the queue
 
         Assumes that we've already checked that we've reached the limit of the number
@@ -298,11 +325,11 @@ class Linearizer:
         logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
 
         new_defer = make_deferred_yieldable(defer.Deferred())
-        entry[1][new_defer] = 1
+        entry.deferreds[new_defer] = 1
 
         def cb(_r):
             logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
-            entry[0] += 1
+            entry.count += 1
 
             # if the code holding the lock completes synchronously, then it
             # will recursively run the next claimant on the list. That can
@@ -331,7 +358,7 @@ class Linearizer:
                 )
 
             # we just have to take ourselves back out of the queue.
-            del entry[1][new_defer]
+            del entry.deferreds[new_defer]
             return e
 
         new_defer.addCallbacks(cb, eb)
@@ -419,14 +446,22 @@ class ReadWriteLock:
         return _ctx_manager()
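
For reference, a hedged usage sketch of ReadWriteLock (not part of the patch; the handler names are hypothetical):

    lock = ReadWriteLock()

    async def read_messages(room_id: str) -> None:
        # Any number of readers may hold the lock for a key at once.
        with await lock.read(room_id):
            ...

    async def purge_history(room_id: str) -> None:
        # A writer waits for active readers to finish and blocks new ones.
        with await lock.write(room_id):
            ...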
 
 
-def _cancelled_to_timed_out_error(value, timeout):
+R = TypeVar("R")
+
+
+def _cancelled_to_timed_out_error(value: R, timeout: float) -> R:
     if isinstance(value, failure.Failure):
         value.trap(CancelledError)
         raise defer.TimeoutError(timeout, "Deferred")
     return value
 
 
-def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
+def timeout_deferred(
+    deferred: defer.Deferred,
+    timeout: float,
+    reactor: IReactorTime,
+    on_timeout_cancel: Optional[Callable[[Any, float], Any]] = None,
+) -> defer.Deferred:
     """The in built twisted `Deferred.addTimeout` fails to time out deferreds
     that have a canceller that throws exceptions. This method creates a new
     deferred that wraps and times out the given deferred, correctly handling
@@ -437,10 +472,10 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
     NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred
 
     Args:
-        deferred (Deferred)
-        timeout (float): Timeout in seconds
-        reactor (twisted.interfaces.IReactorTime): The twisted reactor to use
-        on_timeout_cancel (callable): A callable which is called immediately
+        deferred: The Deferred to potentially timeout.
+        timeout: Timeout in seconds
+        reactor: The twisted reactor to use
+        on_timeout_cancel: A callable which is called immediately
             after the deferred times out, and not if this deferred is
             otherwise cancelled before the timeout.
 
@@ -452,7 +487,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
             CancelledError Failure into a defer.TimeoutError.
 
     Returns:
-        Deferred
+        A new Deferred.
     """
 
     new_d = defer.Deferred()
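
For reference, a hedged usage sketch of timeout_deferred (not part of the patch; slow_lookup is a hypothetical Deferred-returning call):

    from twisted.internet import reactor

    d = slow_lookup()  # hypothetical call returning a Deferred
    timed = timeout_deferred(d, timeout=30.0, reactor=reactor)
    # `timed` errbacks with defer.TimeoutError if `d` has not fired within
    # 30 seconds; the wrapped deferred is cancelled when the timeout fires.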