diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 7253ba120f..581dffd8a0 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -13,23 +13,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import collections
import logging
from contextlib import contextmanager
+from typing import Dict, Sequence, Set, Union
from six.moves import range
+import attr
+
from twisted.internet import defer
from twisted.internet.defer import CancelledError
from twisted.python import failure
-from synapse.util import Clock, logcontext, unwrapFirstError
-
-from .logcontext import (
+from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
run_in_background,
)
+from synapse.util import Clock, unwrapFirstError
logger = logging.getLogger(__name__)
@@ -70,6 +73,10 @@ class ObservableDeferred(object):
def errback(f):
object.__setattr__(self, "_result", (False, f))
while self._observers:
+ # This is a little bit of magic to correctly propagate stack
+ # traces when we `await` on one of the observer deferreds.
+ f.value.__failure__ = f
+
try:
# TODO: Handle errors here.
self._observers.pop().errback(f)
@@ -83,11 +90,12 @@ class ObservableDeferred(object):
deferred.addCallbacks(callback, errback)
- def observe(self):
+ def observe(self) -> defer.Deferred:
"""Observe the underlying deferred.
- Can return either a deferred if the underlying deferred is still pending
- (or has failed), or the actual value. Callers may need to use maybeDeferred.
+ This returns a brand new deferred that is resolved when the underlying
+ deferred is resolved. Interacting with the returned deferred does not
+ affect the underlying deferred.
"""
if not self._result:
d = defer.Deferred()
@@ -95,13 +103,14 @@ class ObservableDeferred(object):
def remove(r):
self._observers.discard(d)
return r
+
d.addBoth(remove)
self._observers.add(d)
return d
else:
success, res = self._result
- return res if success else defer.fail(res)
+ return defer.succeed(res) if success else defer.fail(res)
def observers(self):
return self._observers
@@ -123,7 +132,9 @@ class ObservableDeferred(object):
def __repr__(self):
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
- id(self), self._result, self._deferred,
+ id(self),
+ self._result,
+ self._deferred,
)
@@ -132,9 +143,9 @@ def concurrently_execute(func, args, limit):
the number of concurrent executions.
Args:
- func (func): Function to execute, should return a deferred.
- args (list): List of arguments to pass to func, each invocation of func
- gets a signle argument.
+ func (func): Function to execute, should return a deferred or coroutine.
+ args (Iterable): List of arguments to pass to func, each invocation of func
+ gets a single argument.
limit (int): Maximum number of conccurent executions.
Returns:
@@ -142,18 +153,19 @@ def concurrently_execute(func, args, limit):
"""
it = iter(args)
- @defer.inlineCallbacks
- def _concurrently_execute_inner():
+ async def _concurrently_execute_inner():
try:
while True:
- yield func(next(it))
+ await maybe_awaitable(func(next(it)))
except StopIteration:
pass
- return logcontext.make_deferred_yieldable(defer.gatherResults([
- run_in_background(_concurrently_execute_inner)
- for _ in range(limit)
- ], consumeErrors=True)).addErrback(unwrapFirstError)
+ return make_deferred_yieldable(
+ defer.gatherResults(
+ [run_in_background(_concurrently_execute_inner) for _ in range(limit)],
+ consumeErrors=True,
+ )
+ ).addErrback(unwrapFirstError)
def yieldable_gather_results(func, iter, *args, **kwargs):
@@ -169,10 +181,12 @@ def yieldable_gather_results(func, iter, *args, **kwargs):
Deferred[list]: Resolved when all functions have been invoked, or errors if
one of the function calls fails.
"""
- return logcontext.make_deferred_yieldable(defer.gatherResults([
- run_in_background(func, item, *args, **kwargs)
- for item in iter
- ], consumeErrors=True)).addErrback(unwrapFirstError)
+ return make_deferred_yieldable(
+ defer.gatherResults(
+ [run_in_background(func, item, *args, **kwargs) for item in iter],
+ consumeErrors=True,
+ )
+ ).addErrback(unwrapFirstError)
class Linearizer(object):
@@ -185,6 +199,7 @@ class Linearizer(object):
# do some work.
"""
+
def __init__(self, name=None, max_count=1, clock=None):
"""
Args:
@@ -197,6 +212,7 @@ class Linearizer(object):
if not clock:
from twisted.internet import reactor
+
clock = Clock(reactor)
self._clock = clock
self.max_count = max_count
@@ -205,7 +221,9 @@ class Linearizer(object):
# the first element is the number of things executing, and
# the second element is an OrderedDict, where the keys are deferreds for the
# things blocked from executing.
- self.key_to_defer = {}
+ self.key_to_defer = (
+ {}
+ ) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
def queue(self, key):
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
@@ -221,7 +239,7 @@ class Linearizer(object):
res = self._await_lock(key)
else:
logger.debug(
- "Acquired uncontended linearizer lock %r for key %r", self.name, key,
+ "Acquired uncontended linearizer lock %r for key %r", self.name, key
)
entry[0] += 1
res = defer.succeed(None)
@@ -266,9 +284,7 @@ class Linearizer(object):
"""
entry = self.key_to_defer[key]
- logger.debug(
- "Waiting to acquire linearizer lock %r for key %r", self.name, key,
- )
+ logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
new_defer = make_deferred_yieldable(defer.Deferred())
entry[1][new_defer] = 1
@@ -293,14 +309,14 @@ class Linearizer(object):
logger.info("defer %r got err %r", new_defer, e)
if isinstance(e, CancelledError):
logger.debug(
- "Cancelling wait for linearizer lock %r for key %r",
- self.name, key,
+ "Cancelling wait for linearizer lock %r for key %r", self.name, key
)
else:
- logger.warn(
+ logger.warning(
"Unexpected exception waiting for linearizer lock %r for key %r",
- self.name, key,
+ self.name,
+ key,
)
# we just have to take ourselves back out of the queue.
@@ -334,10 +350,10 @@ class ReadWriteLock(object):
def __init__(self):
# Latest readers queued
- self.key_to_current_readers = {}
+ self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]]
# Latest writer queued
- self.key_to_current_writer = {}
+ self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
@defer.inlineCallbacks
def read(self, key):
@@ -360,7 +376,7 @@ class ReadWriteLock(object):
new_defer.callback(None)
self.key_to_current_readers.get(key, set()).discard(new_defer)
- defer.returnValue(_ctx_manager())
+ return _ctx_manager()
@defer.inlineCallbacks
def write(self, key):
@@ -390,7 +406,7 @@ class ReadWriteLock(object):
if self.key_to_current_writer[key] == new_defer:
self.key_to_current_writer.pop(key)
- defer.returnValue(_ctx_manager())
+ return _ctx_manager()
def _cancelled_to_timed_out_error(value, timeout):
@@ -438,7 +454,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
try:
deferred.cancel()
- except: # noqa: E722, if we throw any exception it'll break time outs
+ except: # noqa: E722, if we throw any exception it'll break time outs
logger.exception("Canceller failed during timeout")
if not new_d.called:
@@ -473,3 +489,30 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
deferred.addCallbacks(success_cb, failure_cb)
return new_d
+
+
+@attr.s(slots=True, frozen=True)
+class DoneAwaitable(object):
+ """Simple awaitable that returns the provided value.
+ """
+
+ value = attr.ib()
+
+ def __await__(self):
+ return self
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ raise StopIteration(self.value)
+
+
+def maybe_awaitable(value):
+ """Convert a value to an awaitable if not already an awaitable.
+ """
+
+ if hasattr(value, "__await__"):
+ return value
+
+ return DoneAwaitable(value)
|