summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12358.misc1
-rw-r--r--synapse/util/async_helpers.py6
-rw-r--r--tests/util/test_linearizer.py51
3 files changed, 53 insertions, 5 deletions
diff --git a/changelog.d/12358.misc b/changelog.d/12358.misc
new file mode 100644
index 0000000000..fcacbcba5c
--- /dev/null
+++ b/changelog.d/12358.misc
@@ -0,0 +1 @@
+Fix a long-standing bug where `Linearizer`s could get stuck if a cancellation were to happen at the wrong time.
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 4b2a16a6a9..650e44de22 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -453,7 +453,11 @@ class Linearizer:
         #
         # This needs to happen while we hold the lock. We could put it on the
         # exit path, but that would slow down the uncontended case.
-        await self._clock.sleep(0)
+        try:
+            await self._clock.sleep(0)
+        except CancelledError:
+            self._release_lock(key, entry)
+            raise
 
         return entry
 
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index c2a209e637..47a1cfbdc1 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -13,7 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable, Hashable, Tuple
+from typing import Hashable, Tuple
+
+from typing_extensions import Protocol
 
 from twisted.internet import defer, reactor
 from twisted.internet.base import ReactorBase
@@ -25,10 +27,15 @@ from synapse.util.async_helpers import Linearizer
 from tests import unittest
 
 
+class UnblockFunction(Protocol):
+    def __call__(self, pump_reactor: bool = True) -> None:
+        ...
+
+
 class LinearizerTestCase(unittest.TestCase):
     def _start_task(
         self, linearizer: Linearizer, key: Hashable
-    ) -> Tuple["Deferred[None]", "Deferred[None]", Callable[[], None]]:
+    ) -> Tuple["Deferred[None]", "Deferred[None]", UnblockFunction]:
         """Starts a task which acquires the linearizer lock, blocks, then completes.
 
         Args:
@@ -52,11 +59,12 @@ class LinearizerTestCase(unittest.TestCase):
 
         d = defer.ensureDeferred(task())
 
-        def unblock() -> None:
+        def unblock(pump_reactor: bool = True) -> None:
             unblock_d.callback(None)
             # The next task, if it exists, will acquire the lock and require a kick of
             # the reactor to advance.
-            self._pump()
+            if pump_reactor:
+                self._pump()
 
         return d, acquired_d, unblock
 
@@ -212,3 +220,38 @@ class LinearizerTestCase(unittest.TestCase):
         )
         unblock3()
         self.successResultOf(d3)
+
+    def test_cancellation_during_sleep(self) -> None:
+        """Tests cancellation during the sleep just after waiting for a `Linearizer`."""
+        linearizer = Linearizer()
+
+        key = object()
+
+        d1, acquired_d1, unblock1 = self._start_task(linearizer, key)
+        self.assertTrue(acquired_d1.called)
+
+        # Create a second task, waiting for the first task.
+        d2, acquired_d2, _ = self._start_task(linearizer, key)
+        self.assertFalse(acquired_d2.called)
+
+        # Create a third task, waiting for the second task.
+        d3, acquired_d3, unblock3 = self._start_task(linearizer, key)
+        self.assertFalse(acquired_d3.called)
+
+        # Once the first task completes, cancel the waiting second task while it is
+        # sleeping just after acquiring the lock.
+        unblock1(pump_reactor=False)
+        self.successResultOf(d1)
+        d2.cancel()
+        self._pump()
+
+        self.assertTrue(d2.called)
+        self.failureResultOf(d2, CancelledError)
+
+        # The third task should continue running.
+        self.assertTrue(
+            acquired_d3.called,
+            "Third task did not get the lock after the second task was cancelled",
+        )
+        unblock3()
+        self.successResultOf(d3)