summary refs log tree commit diff
diff options
context:
space:
mode:
authorSumner Evans <sumner@beeper.com>2022-05-30 02:41:13 -0600
committerGitHub <noreply@github.com>2022-05-30 09:41:13 +0100
commitbda460039941d5f7ae5c66544d5b428c0b33bd22 (patch)
treefe20f696c4253249a8e5482b6993b48fc62d442c
parentAdd a background job to automatically delete stale devices (#12855) (diff)
downloadsynapse-bda460039941d5f7ae5c66544d5b428c0b33bd22.tar.xz
LockStore: fix acquiring a lock via `LockStore.try_acquire_lock` (#12832)
Signed-off-by: Sumner Evans <sumner@beeper.com>
-rw-r--r--changelog.d/12832.bugfix1
-rw-r--r--synapse/storage/databases/main/lock.py19
-rw-r--r--tests/storage/databases/main/test_lock.py54
3 files changed, 73 insertions, 1 deletions
diff --git a/changelog.d/12832.bugfix b/changelog.d/12832.bugfix
new file mode 100644
index 0000000000..497d5184ea
--- /dev/null
+++ b/changelog.d/12832.bugfix
@@ -0,0 +1 @@
+Fixed a bug which allowed multiple async operations to access database locks concurrently. Contributed by @sumnerevans @ Beeper.
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index bedacaf0d7..2d7633fbd5 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import logging
 from types import TracebackType
-from typing import TYPE_CHECKING, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Optional, Set, Tuple, Type
 from weakref import WeakValueDictionary
 
 from twisted.internet.interfaces import IReactorCore
@@ -84,6 +84,8 @@ class LockStore(SQLBaseStore):
             self._on_shutdown,
         )
 
+        self._acquiring_locks: Set[Tuple[str, str]] = set()
+
     @wrap_as_background_process("LockStore._on_shutdown")
     async def _on_shutdown(self) -> None:
         """Called when the server is shutting down"""
@@ -103,6 +105,21 @@ class LockStore(SQLBaseStore):
         context manager if the lock is successfully acquired, which *must* be
         used (otherwise the lock will leak).
         """
+        if (lock_name, lock_key) in self._acquiring_locks:
+            return None
+        try:
+            self._acquiring_locks.add((lock_name, lock_key))
+            return await self._try_acquire_lock(lock_name, lock_key)
+        finally:
+            self._acquiring_locks.discard((lock_name, lock_key))
+
+    async def _try_acquire_lock(
+        self, lock_name: str, lock_key: str
+    ) -> Optional["Lock"]:
+        """Try to acquire a lock for the given name/key. Will return an async
+        context manager if the lock is successfully acquired, which *must* be
+        used (otherwise the lock will leak).
+        """
 
         # Check if this process has taken out a lock and if it's still valid.
         lock = self._live_tokens.get((lock_name, lock_key))
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index 74c6224eb6..3cc2a58d8d 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -12,6 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer, reactor
+from twisted.internet.base import ReactorBase
+from twisted.internet.defer import Deferred
+
 from synapse.server import HomeServer
 from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS
 
@@ -22,6 +26,56 @@ class LockTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs: HomeServer):
         self.store = hs.get_datastores().main
 
+    def test_acquire_contention(self):
+        # Track the number of tasks holding the lock.
+        # Should be at most 1.
+        in_lock = 0
+        max_in_lock = 0
+
+        release_lock: "Deferred[None]" = Deferred()
+
+        async def task():
+            nonlocal in_lock
+            nonlocal max_in_lock
+
+            lock = await self.store.try_acquire_lock("name", "key")
+            if not lock:
+                return
+
+            async with lock:
+                in_lock += 1
+                max_in_lock = max(max_in_lock, in_lock)
+
+                # Block to allow other tasks to attempt to take the lock.
+                await release_lock
+
+                in_lock -= 1
+
+        # Start 3 tasks.
+        task1 = defer.ensureDeferred(task())
+        task2 = defer.ensureDeferred(task())
+        task3 = defer.ensureDeferred(task())
+
+        # Give the reactor a kick so that the database transaction returns.
+        self.pump()
+
+        release_lock.callback(None)
+
+        # Run the tasks to completion.
+        # To work around `Linearizer`s using a different reactor to sleep when
+        # contended (#12841), we call `runUntilCurrent` on
+        # `twisted.internet.reactor`, which is a different reactor to that used
+        # by the homeserver.
+        assert isinstance(reactor, ReactorBase)
+        self.get_success(task1)
+        reactor.runUntilCurrent()
+        self.get_success(task2)
+        reactor.runUntilCurrent()
+        self.get_success(task3)
+
+        # At most one task should have held the lock at a time.
+        self.assertEqual(max_in_lock, 1)
+
     def test_simple_lock(self):
         """Test that we can take out a lock and that while we hold it nobody
         else can take it out.