diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index f1c46836b1..f7af2bca7f 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -13,12 +13,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import collections
import logging
from contextlib import contextmanager
+from typing import Dict, Sequence, Set, Union
from six.moves import range
+import attr
+
from twisted.internet import defer
from twisted.internet.defer import CancelledError
from twisted.python import failure
@@ -69,6 +73,10 @@ class ObservableDeferred(object):
def errback(f):
object.__setattr__(self, "_result", (False, f))
while self._observers:
+ # This is a little bit of magic to correctly propagate stack
+ # traces when we `await` on one of the observer deferreds.
+ f.value.__failure__ = f
+
try:
# TODO: Handle errors here.
self._observers.pop().errback(f)
@@ -82,11 +90,12 @@ class ObservableDeferred(object):
deferred.addCallbacks(callback, errback)
- def observe(self):
+ def observe(self) -> defer.Deferred:
"""Observe the underlying deferred.
- Can return either a deferred if the underlying deferred is still pending
- (or has failed), or the actual value. Callers may need to use maybeDeferred.
+ This returns a brand new deferred that is resolved when the underlying
+ deferred is resolved. Interacting with the returned deferred does not
+ effect the underdlying deferred.
"""
if not self._result:
d = defer.Deferred()
@@ -101,7 +110,7 @@ class ObservableDeferred(object):
return d
else:
success, res = self._result
- return res if success else defer.fail(res)
+ return defer.succeed(res) if success else defer.fail(res)
def observers(self):
return self._observers
@@ -134,9 +143,9 @@ def concurrently_execute(func, args, limit):
the number of concurrent executions.
Args:
- func (func): Function to execute, should return a deferred.
- args (list): List of arguments to pass to func, each invocation of func
- gets a signle argument.
+ func (func): Function to execute, should return a deferred or coroutine.
+ args (Iterable): List of arguments to pass to func, each invocation of func
+ gets a single argument.
limit (int): Maximum number of conccurent executions.
Returns:
@@ -144,11 +153,10 @@ def concurrently_execute(func, args, limit):
"""
it = iter(args)
- @defer.inlineCallbacks
- def _concurrently_execute_inner():
+ async def _concurrently_execute_inner():
try:
while True:
- yield func(next(it))
+ await maybe_awaitable(func(next(it)))
except StopIteration:
pass
@@ -213,7 +221,21 @@ class Linearizer(object):
# the first element is the number of things executing, and
# the second element is an OrderedDict, where the keys are deferreds for the
# things blocked from executing.
- self.key_to_defer = {}
+ self.key_to_defer = (
+ {}
+ ) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
+
+ def is_queued(self, key) -> bool:
+ """Checks whether there is a process queued up waiting
+ """
+ entry = self.key_to_defer.get(key)
+ if not entry:
+ # No entry so nothing is waiting.
+ return False
+
+ # There are waiting deferreds only in the OrderedDict of deferreds is
+ # non-empty.
+ return bool(entry[1])
def queue(self, key):
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
@@ -303,7 +325,7 @@ class Linearizer(object):
)
else:
- logger.warn(
+ logger.warning(
"Unexpected exception waiting for linearizer lock %r for key %r",
self.name,
key,
@@ -340,10 +362,10 @@ class ReadWriteLock(object):
def __init__(self):
# Latest readers queued
- self.key_to_current_readers = {}
+ self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]]
# Latest writer queued
- self.key_to_current_writer = {}
+ self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
@defer.inlineCallbacks
def read(self, key):
@@ -479,3 +501,30 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
deferred.addCallbacks(success_cb, failure_cb)
return new_d
+
+
+@attr.s(slots=True, frozen=True)
+class DoneAwaitable(object):
+ """Simple awaitable that returns the provided value.
+ """
+
+ value = attr.ib()
+
+ def __await__(self):
+ return self
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ raise StopIteration(self.value)
+
+
+def maybe_awaitable(value):
+ """Convert a value to an awaitable if not already an awaitable.
+ """
+
+ if hasattr(value, "__await__"):
+ return value
+
+ return DoneAwaitable(value)
|