summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12353.misc1
-rw-r--r--tests/util/test_linearizer.py224
2 files changed, 134 insertions, 91 deletions
diff --git a/changelog.d/12353.misc b/changelog.d/12353.misc
new file mode 100644
index 0000000000..1d681fb0e3
--- /dev/null
+++ b/changelog.d/12353.misc
@@ -0,0 +1 @@
+Convert `Linearizer` tests from `inlineCallbacks` to async.
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index c4a3917b23..fa132391a1 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -13,160 +13,202 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Callable, Hashable, Tuple
+
 from twisted.internet import defer, reactor
-from twisted.internet.defer import CancelledError
+from twisted.internet.base import ReactorBase
+from twisted.internet.defer import CancelledError, Deferred
 
 from synapse.logging.context import LoggingContext, current_context
-from synapse.util import Clock
 from synapse.util.async_helpers import Linearizer
 
 from tests import unittest
 
 
 class LinearizerTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def test_linearizer(self):
+    def _start_task(
+        self, linearizer: Linearizer, key: Hashable
+    ) -> Tuple["Deferred[None]", "Deferred[None]", Callable[[], None]]:
+        """Starts a task which acquires the linearizer lock, blocks, then completes.
+
+        Args:
+            linearizer: The `Linearizer`.
+            key: The `Linearizer` key.
+
+        Returns:
+            A tuple containing:
+             * A cancellable `Deferred` for the entire task.
+             * A `Deferred` that resolves once the task acquires the lock.
+             * A function that unblocks the task. Must be called by the caller
+               to allow the task to release the lock and complete.
+        """
+        acquired_d: "Deferred[None]" = Deferred()
+        unblock_d: "Deferred[None]" = Deferred()
+
+        async def task() -> None:
+            with await linearizer.queue(key):
+                acquired_d.callback(None)
+                await unblock_d
+
+        d = defer.ensureDeferred(task())
+
+        def unblock() -> None:
+            unblock_d.callback(None)
+            # The next task, if it exists, will acquire the lock and require a kick of
+            # the reactor to advance.
+            self._pump()
+
+        return d, acquired_d, unblock
+
+    def _pump(self) -> None:
+        """Pump the reactor to advance `Linearizer`s."""
+        assert isinstance(reactor, ReactorBase)
+        while reactor.getDelayedCalls():
+            reactor.runUntilCurrent()
+
+    def test_linearizer(self) -> None:
+        """Tests that a task is queued up behind an earlier task."""
         linearizer = Linearizer()
 
         key = object()
 
-        d1 = linearizer.queue(key)
-        cm1 = yield d1
+        _, acquired_d1, unblock1 = self._start_task(linearizer, key)
+        self.assertTrue(acquired_d1.called)
+
+        _, acquired_d2, unblock2 = self._start_task(linearizer, key)
+        self.assertFalse(acquired_d2.called)
 
-        d2 = linearizer.queue(key)
-        self.assertFalse(d2.called)
+        # Once the first task is done, the second task can continue.
+        unblock1()
+        self.assertTrue(acquired_d2.called)
 
-        with cm1:
-            self.assertFalse(d2.called)
+        unblock2()
 
-        with (yield d2):
-            pass
+    def test_linearizer_is_queued(self) -> None:
+        """Tests `Linearizer.is_queued`.
 
-    @defer.inlineCallbacks
-    def test_linearizer_is_queued(self):
+        Runs through the same scenario as `test_linearizer`.
+        """
         linearizer = Linearizer()
 
         key = object()
 
-        d1 = linearizer.queue(key)
-        cm1 = yield d1
+        _, acquired_d1, unblock1 = self._start_task(linearizer, key)
+        self.assertTrue(acquired_d1.called)
 
-        # Since d1 gets called immediately, "is_queued" should return false.
+        # Since the first task acquires the lock immediately, "is_queued" should return
+        # false.
         self.assertFalse(linearizer.is_queued(key))
 
-        d2 = linearizer.queue(key)
-        self.assertFalse(d2.called)
+        _, acquired_d2, unblock2 = self._start_task(linearizer, key)
+        self.assertFalse(acquired_d2.called)
 
-        # Now d2 is queued up behind successful completion of cm1
+        # Now the second task is queued up behind the first.
         self.assertTrue(linearizer.is_queued(key))
 
-        with cm1:
-            self.assertFalse(d2.called)
-
-            # cm1 still not done, so d2 still queued.
-            self.assertTrue(linearizer.is_queued(key))
+        unblock1()
 
-        # And now d2 is called and nothing is in the queue again
+        # And now the second task acquires the lock and nothing is in the queue again.
+        self.assertTrue(acquired_d2.called)
         self.assertFalse(linearizer.is_queued(key))
 
-        with (yield d2):
-            self.assertFalse(linearizer.is_queued(key))
-
+        unblock2()
         self.assertFalse(linearizer.is_queued(key))
 
-    def test_lots_of_queued_things(self):
-        # we have one slow thing, and lots of fast things queued up behind it.
-        # it should *not* explode the stack.
+    def test_lots_of_queued_things(self) -> None:
+        """Tests lots of fast things queued up behind a slow thing.
+
+        The stack should *not* explode when the slow thing completes.
+        """
         linearizer = Linearizer()
+        key = ""
 
-        @defer.inlineCallbacks
-        def func(i, sleep=False):
+        async def func(i: int) -> None:
             with LoggingContext("func(%s)" % i) as lc:
-                with (yield linearizer.queue("")):
+                with (await linearizer.queue(key)):
                     self.assertEqual(current_context(), lc)
-                    if sleep:
-                        yield Clock(reactor).sleep(0)
 
                 self.assertEqual(current_context(), lc)
 
-        func(0, sleep=True)
+        _, _, unblock = self._start_task(linearizer, key)
         for i in range(1, 100):
-            func(i)
+            defer.ensureDeferred(func(i))
 
-        return func(1000)
+        d = defer.ensureDeferred(func(1000))
+        unblock()
+        self.successResultOf(d)
 
-    @defer.inlineCallbacks
-    def test_multiple_entries(self):
+    def test_multiple_entries(self) -> None:
+        """Tests a `Linearizer` with a concurrency above 1."""
         limiter = Linearizer(max_count=3)
 
         key = object()
 
-        d1 = limiter.queue(key)
-        cm1 = yield d1
-
-        d2 = limiter.queue(key)
-        cm2 = yield d2
-
-        d3 = limiter.queue(key)
-        cm3 = yield d3
-
-        d4 = limiter.queue(key)
-        self.assertFalse(d4.called)
-
-        d5 = limiter.queue(key)
-        self.assertFalse(d5.called)
+        _, acquired_d1, unblock1 = self._start_task(limiter, key)
+        self.assertTrue(acquired_d1.called)
 
-        with cm1:
-            self.assertFalse(d4.called)
-            self.assertFalse(d5.called)
+        _, acquired_d2, unblock2 = self._start_task(limiter, key)
+        self.assertTrue(acquired_d2.called)
 
-        cm4 = yield d4
-        self.assertFalse(d5.called)
+        _, acquired_d3, unblock3 = self._start_task(limiter, key)
+        self.assertTrue(acquired_d3.called)
 
-        with cm3:
-            self.assertFalse(d5.called)
+        # These next two tasks have to wait.
+        _, acquired_d4, unblock4 = self._start_task(limiter, key)
+        self.assertFalse(acquired_d4.called)
 
-        cm5 = yield d5
+        _, acquired_d5, unblock5 = self._start_task(limiter, key)
+        self.assertFalse(acquired_d5.called)
 
-        with cm2:
-            pass
+        # Once the first task completes, the fourth task can continue.
+        unblock1()
+        self.assertTrue(acquired_d4.called)
+        self.assertFalse(acquired_d5.called)
 
-        with cm4:
-            pass
+        # Once the third task completes, the fifth task can continue.
+        unblock3()
+        self.assertTrue(acquired_d5.called)
 
-        with cm5:
-            pass
+        # Make all tasks finish.
+        unblock2()
+        unblock4()
+        unblock5()
 
-        d6 = limiter.queue(key)
-        with (yield d6):
-            pass
+        # The next task shouldn't have to wait.
+        _, acquired_d6, unblock6 = self._start_task(limiter, key)
+        self.assertTrue(acquired_d6)
+        unblock6()
 
-    @defer.inlineCallbacks
-    def test_cancellation(self):
+    def test_cancellation(self) -> None:
+        """Tests cancellation while waiting for a `Linearizer`."""
         linearizer = Linearizer()
 
         key = object()
 
-        d1 = linearizer.queue(key)
-        cm1 = yield d1
+        d1, acquired_d1, unblock1 = self._start_task(linearizer, key)
+        self.assertTrue(acquired_d1.called)
 
-        d2 = linearizer.queue(key)
-        self.assertFalse(d2.called)
+        # Create a second task, waiting for the first task.
+        d2, acquired_d2, _ = self._start_task(linearizer, key)
+        self.assertFalse(acquired_d2.called)
 
-        d3 = linearizer.queue(key)
-        self.assertFalse(d3.called)
+        # Create a third task, waiting for the second task.
+        d3, acquired_d3, unblock3 = self._start_task(linearizer, key)
+        self.assertFalse(acquired_d3.called)
 
+        # Cancel the waiting second task.
         d2.cancel()
 
-        with cm1:
-            pass
+        unblock1()
+        self.successResultOf(d1)
 
         self.assertTrue(d2.called)
-        try:
-            yield d2
-            self.fail("Expected d2 to raise CancelledError")
-        except CancelledError:
-            pass
-
-        with (yield d3):
-            pass
+        self.failureResultOf(d2, CancelledError)
+
+        # The third task should continue running.
+        self.assertTrue(
+            acquired_d3.called,
+            "Third task did not get the lock after the second task was cancelled",
+        )
+        unblock3()
+        self.successResultOf(d3)