summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-04-19 14:58:23 +0100
committerErik Johnston <erik@matrix.org>2022-04-21 11:49:51 +0100
commitfe0048c2f78257c50dceaadb7672456ecbfc7643 (patch)
tree7c81a2e9864bf9a1608a40ae848db4b0331467a5
parentClarify changelog entry (diff)
downloadsynapse-fe0048c2f78257c50dceaadb7672456ecbfc7643.tar.xz
Add a `AwakenableSleeper` class
-rw-r--r--synapse/util/async_helpers.py57
-rw-r--r--tests/util/test_async_helpers.py38
2 files changed, 95 insertions, 0 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 650e44de22..83830dc4b2 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -734,3 +734,60 @@ def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
     new_deferred: "defer.Deferred[T]" = defer.Deferred(handle_cancel)
     deferred.chainDeferred(new_deferred)
     return new_deferred
+
+
+class AwakenableSleeper:
+    """Allows explicitly waking up deferreds related to an entity that are
+    currently sleeping.
+    """
+
+    def __init__(self, reactor: IReactorTime) -> None:
+        self._streams: Dict[str, Set[defer.Deferred[None]]] = {}
+        self._reactor = reactor
+
+    def wake(self, name: str) -> None:
+        """Wake everything related to `name` that is currently sleeping."""
+        stream_set = self._streams.pop(name, set())
+        for deferred in set(stream_set):
+            try:
+                with PreserveLoggingContext():
+                    deferred.callback(None)
+            except Exception:
+                pass
+
+    async def sleep(self, name: str, delay_ms: int) -> None:
+        """Sleep for the given number of milliseconds, or return if the given
+        `name` is explicitly woken up.
+        """
+
+        # Create a deferred that gets called in N seconds
+        sleep_deferred: "defer.Deferred[None]" = defer.Deferred()
+        call = self._reactor.callLater(delay_ms / 1000, sleep_deferred.callback, None)
+
+        # Create a deferred that will get called if `wake` is called with
+        # the same `name`.
+        stream_set = self._streams.setdefault(name, set())
+        notify_deferred: "defer.Deferred[None]" = defer.Deferred()
+        stream_set.add(notify_deferred)
+
+        try:
+            # Wait for either the delay or for `wake` to be called.
+            await make_deferred_yieldable(
+                defer.DeferredList(
+                    [sleep_deferred, notify_deferred],
+                    fireOnOneCallback=True,
+                    fireOnOneErrback=True,
+                    consumeErrors=True,
+                )
+            )
+        finally:
+            # Clean up the state
+            stream_set.discard(notify_deferred)
+
+            curr_stream_set = self._streams.get(name)
+            if curr_stream_set is not None and len(curr_stream_set) == 0:
+                self._streams.pop(name)
+
+            # Cancel the sleep if we were woken up
+            if call.active():
+                call.cancel()
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index e5bc416de1..4424c57bd3 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -28,6 +28,7 @@ from synapse.logging.context import (
     make_deferred_yieldable,
 )
 from synapse.util.async_helpers import (
+    AwakenableSleeper,
     ObservableDeferred,
     concurrently_execute,
     delay_cancellation,
@@ -35,6 +36,7 @@ from synapse.util.async_helpers import (
     timeout_deferred,
 )
 
+from tests.server import get_clock
 from tests.unittest import TestCase
 
 
@@ -467,3 +469,39 @@ class DelayCancellationTests(TestCase):
         # logging context.
         blocking_d.callback(None)
         self.successResultOf(d)
+
+
+class AwakenableSleeperTests(TestCase):
+    "Tests AwakenableSleeper"
+
+    def test_sleep(self):
+        reactor, _ = get_clock()
+        sleeper = AwakenableSleeper(reactor)
+
+        d = defer.ensureDeferred(sleeper.sleep("name", 1000))
+
+        reactor.pump([0.0])
+        self.assertFalse(d.called)
+
+        reactor.advance(0.5)
+        self.assertFalse(d.called)
+
+        reactor.advance(0.6)
+        self.assertTrue(d.called)
+
+    def test_explicit_wake(self):
+        reactor, _ = get_clock()
+        sleeper = AwakenableSleeper(reactor)
+
+        d = defer.ensureDeferred(sleeper.sleep("name", 1000))
+
+        reactor.pump([0.0])
+        self.assertFalse(d.called)
+
+        reactor.advance(0.5)
+        self.assertFalse(d.called)
+
+        sleeper.wake("name")
+        self.assertTrue(d.called)
+
+        reactor.advance(0.6)