diff options
author | Erik Johnston <erik@matrix.org> | 2022-04-19 14:58:23 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2022-04-21 11:49:51 +0100 |
commit | fe0048c2f78257c50dceaadb7672456ecbfc7643 (patch) | |
tree | 7c81a2e9864bf9a1608a40ae848db4b0331467a5 | |
parent | Clarify changelog entry (diff) | |
download | synapse-fe0048c2f78257c50dceaadb7672456ecbfc7643.tar.xz |
Add a `AwakenableSleeper` class
-rw-r--r-- | synapse/util/async_helpers.py | 57 | ||||
-rw-r--r-- | tests/util/test_async_helpers.py | 38 |
2 files changed, 95 insertions, 0 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 650e44de22..83830dc4b2 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -734,3 +734,60 @@ def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": new_deferred: "defer.Deferred[T]" = defer.Deferred(handle_cancel) deferred.chainDeferred(new_deferred) return new_deferred + + +class AwakenableSleeper: + """Allows explicitly waking up deferreds related to an entity that are + currently sleeping. + """ + + def __init__(self, reactor: IReactorTime) -> None: + self._streams: Dict[str, Set[defer.Deferred[None]]] = {} + self._reactor = reactor + + def wake(self, name: str) -> None: + """Wake everything related to `name` that is currently sleeping.""" + stream_set = self._streams.pop(name, set()) + for deferred in set(stream_set): + try: + with PreserveLoggingContext(): + deferred.callback(None) + except Exception: + pass + + async def sleep(self, name: str, delay_ms: int) -> None: + """Sleep for the given number of milliseconds, or return if the given + `name` is explicitly woken up. + """ + + # Create a deferred that gets called in N seconds + sleep_deferred: "defer.Deferred[None]" = defer.Deferred() + call = self._reactor.callLater(delay_ms / 1000, sleep_deferred.callback, None) + + # Create a deferred that will get called if `wake` is called with + # the same `name`. + stream_set = self._streams.setdefault(name, set()) + notify_deferred: "defer.Deferred[None]" = defer.Deferred() + stream_set.add(notify_deferred) + + try: + # Wait for either the delay or for `wake` to be called. + await make_deferred_yieldable( + defer.DeferredList( + [sleep_deferred, notify_deferred], + fireOnOneCallback=True, + fireOnOneErrback=True, + consumeErrors=True, + ) + ) + finally: + # Clean up the state + stream_set.discard(notify_deferred) + + curr_stream_set = self._streams.get(name) + if curr_stream_set is not None and len(curr_stream_set) == 0: + self._streams.pop(name) + + # Cancel the sleep if we were woken up + if call.active(): + call.cancel() diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index e5bc416de1..4424c57bd3 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -28,6 +28,7 @@ from synapse.logging.context import ( make_deferred_yieldable, ) from synapse.util.async_helpers import ( + AwakenableSleeper, ObservableDeferred, concurrently_execute, delay_cancellation, @@ -35,6 +36,7 @@ from synapse.util.async_helpers import ( timeout_deferred, ) +from tests.server import get_clock from tests.unittest import TestCase @@ -467,3 +469,39 @@ class DelayCancellationTests(TestCase): # logging context. blocking_d.callback(None) self.successResultOf(d) + + +class AwakenableSleeperTests(TestCase): + "Tests AwakenableSleeper" + + def test_sleep(self): + reactor, _ = get_clock() + sleeper = AwakenableSleeper(reactor) + + d = defer.ensureDeferred(sleeper.sleep("name", 1000)) + + reactor.pump([0.0]) + self.assertFalse(d.called) + + reactor.advance(0.5) + self.assertFalse(d.called) + + reactor.advance(0.6) + self.assertTrue(d.called) + + def test_explicit_wake(self): + reactor, _ = get_clock() + sleeper = AwakenableSleeper(reactor) + + d = defer.ensureDeferred(sleeper.sleep("name", 1000)) + + reactor.pump([0.0]) + self.assertFalse(d.called) + + reactor.advance(0.5) + self.assertFalse(d.called) + + sleeper.wake("name") + self.assertTrue(d.called) + + reactor.advance(0.6) |