diff options
Diffstat (limited to 'tests/util')
-rw-r--r-- | tests/util/test_linearizer.py | 224 |
1 files changed, 133 insertions, 91 deletions
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index c4a3917b23..fa132391a1 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -13,160 +13,202 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Callable, Hashable, Tuple + from twisted.internet import defer, reactor -from twisted.internet.defer import CancelledError +from twisted.internet.base import ReactorBase +from twisted.internet.defer import CancelledError, Deferred from synapse.logging.context import LoggingContext, current_context -from synapse.util import Clock from synapse.util.async_helpers import Linearizer from tests import unittest class LinearizerTestCase(unittest.TestCase): - @defer.inlineCallbacks - def test_linearizer(self): + def _start_task( + self, linearizer: Linearizer, key: Hashable + ) -> Tuple["Deferred[None]", "Deferred[None]", Callable[[], None]]: + """Starts a task which acquires the linearizer lock, blocks, then completes. + + Args: + linearizer: The `Linearizer`. + key: The `Linearizer` key. + + Returns: + A tuple containing: + * A cancellable `Deferred` for the entire task. + * A `Deferred` that resolves once the task acquires the lock. + * A function that unblocks the task. Must be called by the caller + to allow the task to release the lock and complete. + """ + acquired_d: "Deferred[None]" = Deferred() + unblock_d: "Deferred[None]" = Deferred() + + async def task() -> None: + with await linearizer.queue(key): + acquired_d.callback(None) + await unblock_d + + d = defer.ensureDeferred(task()) + + def unblock() -> None: + unblock_d.callback(None) + # The next task, if it exists, will acquire the lock and require a kick of + # the reactor to advance. + self._pump() + + return d, acquired_d, unblock + + def _pump(self) -> None: + """Pump the reactor to advance `Linearizer`s.""" + assert isinstance(reactor, ReactorBase) + while reactor.getDelayedCalls(): + reactor.runUntilCurrent() + + def test_linearizer(self) -> None: + """Tests that a task is queued up behind an earlier task.""" linearizer = Linearizer() key = object() - d1 = linearizer.queue(key) - cm1 = yield d1 + _, acquired_d1, unblock1 = self._start_task(linearizer, key) + self.assertTrue(acquired_d1.called) + + _, acquired_d2, unblock2 = self._start_task(linearizer, key) + self.assertFalse(acquired_d2.called) - d2 = linearizer.queue(key) - self.assertFalse(d2.called) + # Once the first task is done, the second task can continue. + unblock1() + self.assertTrue(acquired_d2.called) - with cm1: - self.assertFalse(d2.called) + unblock2() - with (yield d2): - pass + def test_linearizer_is_queued(self) -> None: + """Tests `Linearizer.is_queued`. - @defer.inlineCallbacks - def test_linearizer_is_queued(self): + Runs through the same scenario as `test_linearizer`. + """ linearizer = Linearizer() key = object() - d1 = linearizer.queue(key) - cm1 = yield d1 + _, acquired_d1, unblock1 = self._start_task(linearizer, key) + self.assertTrue(acquired_d1.called) - # Since d1 gets called immediately, "is_queued" should return false. + # Since the first task acquires the lock immediately, "is_queued" should return + # false. self.assertFalse(linearizer.is_queued(key)) - d2 = linearizer.queue(key) - self.assertFalse(d2.called) + _, acquired_d2, unblock2 = self._start_task(linearizer, key) + self.assertFalse(acquired_d2.called) - # Now d2 is queued up behind successful completion of cm1 + # Now the second task is queued up behind the first. self.assertTrue(linearizer.is_queued(key)) - with cm1: - self.assertFalse(d2.called) - - # cm1 still not done, so d2 still queued. - self.assertTrue(linearizer.is_queued(key)) + unblock1() - # And now d2 is called and nothing is in the queue again + # And now the second task acquires the lock and nothing is in the queue again. + self.assertTrue(acquired_d2.called) self.assertFalse(linearizer.is_queued(key)) - with (yield d2): - self.assertFalse(linearizer.is_queued(key)) - + unblock2() self.assertFalse(linearizer.is_queued(key)) - def test_lots_of_queued_things(self): - # we have one slow thing, and lots of fast things queued up behind it. - # it should *not* explode the stack. + def test_lots_of_queued_things(self) -> None: + """Tests lots of fast things queued up behind a slow thing. + + The stack should *not* explode when the slow thing completes. + """ linearizer = Linearizer() + key = "" - @defer.inlineCallbacks - def func(i, sleep=False): + async def func(i: int) -> None: with LoggingContext("func(%s)" % i) as lc: - with (yield linearizer.queue("")): + with (await linearizer.queue(key)): self.assertEqual(current_context(), lc) - if sleep: - yield Clock(reactor).sleep(0) self.assertEqual(current_context(), lc) - func(0, sleep=True) + _, _, unblock = self._start_task(linearizer, key) for i in range(1, 100): - func(i) + defer.ensureDeferred(func(i)) - return func(1000) + d = defer.ensureDeferred(func(1000)) + unblock() + self.successResultOf(d) - @defer.inlineCallbacks - def test_multiple_entries(self): + def test_multiple_entries(self) -> None: + """Tests a `Linearizer` with a concurrency above 1.""" limiter = Linearizer(max_count=3) key = object() - d1 = limiter.queue(key) - cm1 = yield d1 - - d2 = limiter.queue(key) - cm2 = yield d2 - - d3 = limiter.queue(key) - cm3 = yield d3 - - d4 = limiter.queue(key) - self.assertFalse(d4.called) - - d5 = limiter.queue(key) - self.assertFalse(d5.called) + _, acquired_d1, unblock1 = self._start_task(limiter, key) + self.assertTrue(acquired_d1.called) - with cm1: - self.assertFalse(d4.called) - self.assertFalse(d5.called) + _, acquired_d2, unblock2 = self._start_task(limiter, key) + self.assertTrue(acquired_d2.called) - cm4 = yield d4 - self.assertFalse(d5.called) + _, acquired_d3, unblock3 = self._start_task(limiter, key) + self.assertTrue(acquired_d3.called) - with cm3: - self.assertFalse(d5.called) + # These next two tasks have to wait. + _, acquired_d4, unblock4 = self._start_task(limiter, key) + self.assertFalse(acquired_d4.called) - cm5 = yield d5 + _, acquired_d5, unblock5 = self._start_task(limiter, key) + self.assertFalse(acquired_d5.called) - with cm2: - pass + # Once the first task completes, the fourth task can continue. + unblock1() + self.assertTrue(acquired_d4.called) + self.assertFalse(acquired_d5.called) - with cm4: - pass + # Once the third task completes, the fifth task can continue. + unblock3() + self.assertTrue(acquired_d5.called) - with cm5: - pass + # Make all tasks finish. + unblock2() + unblock4() + unblock5() - d6 = limiter.queue(key) - with (yield d6): - pass + # The next task shouldn't have to wait. + _, acquired_d6, unblock6 = self._start_task(limiter, key) + self.assertTrue(acquired_d6) + unblock6() - @defer.inlineCallbacks - def test_cancellation(self): + def test_cancellation(self) -> None: + """Tests cancellation while waiting for a `Linearizer`.""" linearizer = Linearizer() key = object() - d1 = linearizer.queue(key) - cm1 = yield d1 + d1, acquired_d1, unblock1 = self._start_task(linearizer, key) + self.assertTrue(acquired_d1.called) - d2 = linearizer.queue(key) - self.assertFalse(d2.called) + # Create a second task, waiting for the first task. + d2, acquired_d2, _ = self._start_task(linearizer, key) + self.assertFalse(acquired_d2.called) - d3 = linearizer.queue(key) - self.assertFalse(d3.called) + # Create a third task, waiting for the second task. + d3, acquired_d3, unblock3 = self._start_task(linearizer, key) + self.assertFalse(acquired_d3.called) + # Cancel the waiting second task. d2.cancel() - with cm1: - pass + unblock1() + self.successResultOf(d1) self.assertTrue(d2.called) - try: - yield d2 - self.fail("Expected d2 to raise CancelledError") - except CancelledError: - pass - - with (yield d3): - pass + self.failureResultOf(d2, CancelledError) + + # The third task should continue running. + self.assertTrue( + acquired_d3.called, + "Third task did not get the lock after the second task was cancelled", + ) + unblock3() + self.successResultOf(d3) |