summary refs log tree commit diff
path: root/synapse/util/async_helpers.py
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2020-06-10 11:42:30 +0100
committerBrendan Abolivier <babolivier@matrix.org>2020-06-10 11:42:30 +0100
commitec0a7b9034806d6b2ba086bae58f5c6b0fd14672 (patch)
treef2af547b1342795e10548f8fb7a9cfc93e03df37 /synapse/util/async_helpers.py
parentchangelog (diff)
parent1.15.0rc1 (diff)
downloadsynapse-ec0a7b9034806d6b2ba086bae58f5c6b0fd14672.tar.xz
Merge branch 'develop' into babolivier/mark_unread
Diffstat (limited to 'synapse/util/async_helpers.py')
-rw-r--r--synapse/util/async_helpers.py77
1 files changed, 63 insertions, 14 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index f1c46836b1..f7af2bca7f 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -13,12 +13,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import collections
 import logging
 from contextlib import contextmanager
+from typing import Dict, Sequence, Set, Union
 
 from six.moves import range
 
+import attr
+
 from twisted.internet import defer
 from twisted.internet.defer import CancelledError
 from twisted.python import failure
@@ -69,6 +73,10 @@ class ObservableDeferred(object):
         def errback(f):
             object.__setattr__(self, "_result", (False, f))
             while self._observers:
+                # This is a little bit of magic to correctly propagate stack
+                # traces when we `await` on one of the observer deferreds.
+                f.value.__failure__ = f
+
                 try:
                     # TODO: Handle errors here.
                     self._observers.pop().errback(f)
@@ -82,11 +90,12 @@ class ObservableDeferred(object):
 
         deferred.addCallbacks(callback, errback)
 
-    def observe(self):
+    def observe(self) -> defer.Deferred:
         """Observe the underlying deferred.
 
-        Can return either a deferred if the underlying deferred is still pending
-        (or has failed), or the actual value. Callers may need to use maybeDeferred.
+        This returns a brand new deferred that is resolved when the underlying
+        deferred is resolved. Interacting with the returned deferred does not
+        effect the underdlying deferred.
         """
         if not self._result:
             d = defer.Deferred()
@@ -101,7 +110,7 @@ class ObservableDeferred(object):
             return d
         else:
             success, res = self._result
-            return res if success else defer.fail(res)
+            return defer.succeed(res) if success else defer.fail(res)
 
     def observers(self):
         return self._observers
@@ -134,9 +143,9 @@ def concurrently_execute(func, args, limit):
     the number of concurrent executions.
 
     Args:
-        func (func): Function to execute, should return a deferred.
-        args (list): List of arguments to pass to func, each invocation of func
-            gets a signle argument.
+        func (func): Function to execute, should return a deferred or coroutine.
+        args (Iterable): List of arguments to pass to func, each invocation of func
+            gets a single argument.
         limit (int): Maximum number of conccurent executions.
 
     Returns:
@@ -144,11 +153,10 @@ def concurrently_execute(func, args, limit):
     """
     it = iter(args)
 
-    @defer.inlineCallbacks
-    def _concurrently_execute_inner():
+    async def _concurrently_execute_inner():
         try:
             while True:
-                yield func(next(it))
+                await maybe_awaitable(func(next(it)))
         except StopIteration:
             pass
 
@@ -213,7 +221,21 @@ class Linearizer(object):
         # the first element is the number of things executing, and
         # the second element is an OrderedDict, where the keys are deferreds for the
         # things blocked from executing.
-        self.key_to_defer = {}
+        self.key_to_defer = (
+            {}
+        )  # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
+
+    def is_queued(self, key) -> bool:
+        """Checks whether there is a process queued up waiting
+        """
+        entry = self.key_to_defer.get(key)
+        if not entry:
+            # No entry so nothing is waiting.
+            return False
+
+        # There are waiting deferreds only in the OrderedDict of deferreds is
+        # non-empty.
+        return bool(entry[1])
 
     def queue(self, key):
         # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
@@ -303,7 +325,7 @@ class Linearizer(object):
                 )
 
             else:
-                logger.warn(
+                logger.warning(
                     "Unexpected exception waiting for linearizer lock %r for key %r",
                     self.name,
                     key,
@@ -340,10 +362,10 @@ class ReadWriteLock(object):
 
     def __init__(self):
         # Latest readers queued
-        self.key_to_current_readers = {}
+        self.key_to_current_readers = {}  # type: Dict[str, Set[defer.Deferred]]
 
         # Latest writer queued
-        self.key_to_current_writer = {}
+        self.key_to_current_writer = {}  # type: Dict[str, defer.Deferred]
 
     @defer.inlineCallbacks
     def read(self, key):
@@ -479,3 +501,30 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
     deferred.addCallbacks(success_cb, failure_cb)
 
     return new_d
+
+
+@attr.s(slots=True, frozen=True)
+class DoneAwaitable(object):
+    """Simple awaitable that returns the provided value.
+    """
+
+    value = attr.ib()
+
+    def __await__(self):
+        return self
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        raise StopIteration(self.value)
+
+
+def maybe_awaitable(value):
+    """Convert a value to an awaitable if not already an awaitable.
+    """
+
+    if hasattr(value, "__await__"):
+        return value
+
+    return DoneAwaitable(value)