summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/storage/databases/main/test_lock.py283
1 files changed, 282 insertions, 1 deletions
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index 56cb49d9b5..ad454f6dd8 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -166,4 +166,285 @@ class LockTestCase(unittest.HomeserverTestCase):
         # Now call the shutdown code
         self.get_success(self.store._on_shutdown())
 
-        self.assertEqual(self.store._live_tokens, {})
+        self.assertEqual(self.store._live_lock_tokens, {})
+
+
+class ReadWriteLockTestCase(unittest.HomeserverTestCase):
+    """Test the read/write lock implementation."""
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+
+    def test_acquire_write_contention(self) -> None:
+        """Test that we can only acquire one write lock at a time"""
+        # Track the number of tasks holding the lock.
+        # Should be at most 1.
+        in_lock = 0
+        max_in_lock = 0
+
+        release_lock: "Deferred[None]" = Deferred()
+
+        async def task() -> None:
+            nonlocal in_lock
+            nonlocal max_in_lock
+
+            lock = await self.store.try_acquire_read_write_lock(
+                "name", "key", write=True
+            )
+            if not lock:
+                return
+
+            async with lock:
+                in_lock += 1
+                max_in_lock = max(max_in_lock, in_lock)
+
+                # Block to allow other tasks to attempt to take the lock.
+                await release_lock
+
+                in_lock -= 1
+
+        # Start 3 tasks.
+        task1 = defer.ensureDeferred(task())
+        task2 = defer.ensureDeferred(task())
+        task3 = defer.ensureDeferred(task())
+
+        # Give the reactor a kick so that the database transaction returns.
+        self.pump()
+
+        release_lock.callback(None)
+
+        # Run the tasks to completion.
+        # To work around `Linearizer`s using a different reactor to sleep when
+        # contended (#12841), we call `runUntilCurrent` on
+        # `twisted.internet.reactor`, which is a different reactor to that used
+        # by the homeserver.
+        assert isinstance(reactor, ReactorBase)
+        self.get_success(task1)
+        reactor.runUntilCurrent()
+        self.get_success(task2)
+        reactor.runUntilCurrent()
+        self.get_success(task3)
+
+        # At most one task should have held the lock at a time.
+        self.assertEqual(max_in_lock, 1)
+
+    def test_acquire_multiple_reads(self) -> None:
+        """Test that we can acquire multiple read locks at a time"""
+        # Track the number of tasks holding the lock.
+        in_lock = 0
+        max_in_lock = 0
+
+        release_lock: "Deferred[None]" = Deferred()
+
+        async def task() -> None:
+            nonlocal in_lock
+            nonlocal max_in_lock
+
+            lock = await self.store.try_acquire_read_write_lock(
+                "name", "key", write=False
+            )
+            if not lock:
+                return
+
+            async with lock:
+                in_lock += 1
+                max_in_lock = max(max_in_lock, in_lock)
+
+                # Block to allow other tasks to attempt to take the lock.
+                await release_lock
+
+                in_lock -= 1
+
+        # Start 3 tasks.
+        task1 = defer.ensureDeferred(task())
+        task2 = defer.ensureDeferred(task())
+        task3 = defer.ensureDeferred(task())
+
+        # Give the reactor a kick so that the database transaction returns.
+        self.pump()
+
+        release_lock.callback(None)
+
+        # Run the tasks to completion.
+        # To work around `Linearizer`s using a different reactor to sleep when
+        # contended (#12841), we call `runUntilCurrent` on
+        # `twisted.internet.reactor`, which is a different reactor to that used
+        # by the homeserver.
+        assert isinstance(reactor, ReactorBase)
+        self.get_success(task1)
+        reactor.runUntilCurrent()
+        self.get_success(task2)
+        reactor.runUntilCurrent()
+        self.get_success(task3)
+
+        # At most one task should have held the lock at a time.
+        self.assertEqual(max_in_lock, 3)
+
+    def test_write_lock_acquired(self) -> None:
+        """Test that we can take out a write lock and that while we hold it
+        nobody else can take it out.
+        """
+        # First to acquire this lock, so it should complete
+        lock = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        assert lock is not None
+
+        # Enter the context manager
+        self.get_success(lock.__aenter__())
+
+        # Attempting to acquire the lock again fails, as both read and write.
+        lock2 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        self.assertIsNone(lock2)
+
+        lock3 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=False)
+        )
+        self.assertIsNone(lock3)
+
+        # Calling `is_still_valid` reports true.
+        self.assertTrue(self.get_success(lock.is_still_valid()))
+
+        # Drop the lock
+        self.get_success(lock.__aexit__(None, None, None))
+
+        # We can now acquire the lock again.
+        lock4 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        assert lock4 is not None
+        self.get_success(lock4.__aenter__())
+        self.get_success(lock4.__aexit__(None, None, None))
+
+    def test_read_lock_acquired(self) -> None:
+        """Test that we can take out a read lock and that while we hold it
+        only other reads can use it.
+        """
+        # First to acquire this lock, so it should complete
+        lock = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=False)
+        )
+        assert lock is not None
+
+        # Enter the context manager
+        self.get_success(lock.__aenter__())
+
+        # Attempting to acquire the write lock fails
+        lock2 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        self.assertIsNone(lock2)
+
+        # Attempting to acquire a read lock succeeds
+        lock3 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=False)
+        )
+        assert lock3 is not None
+        self.get_success(lock3.__aenter__())
+
+        # Calling `is_still_valid` reports true.
+        self.assertTrue(self.get_success(lock.is_still_valid()))
+
+        # Drop the first lock
+        self.get_success(lock.__aexit__(None, None, None))
+
+        # Attempting to acquire the write lock still fails, as lock3 is still
+        # active.
+        lock4 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        self.assertIsNone(lock4)
+
+        # Drop the still open third lock
+        self.get_success(lock3.__aexit__(None, None, None))
+
+        # We can now acquire the lock again.
+        lock5 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        assert lock5 is not None
+        self.get_success(lock5.__aenter__())
+        self.get_success(lock5.__aexit__(None, None, None))
+
+    def test_maintain_lock(self) -> None:
+        """Test that we don't time out locks while they're still active (lock is
+        renewed in the background if the process is still alive)"""
+
+        lock = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        assert lock is not None
+
+        self.get_success(lock.__aenter__())
+
+        # Wait for ages with the lock, we should not be able to get the lock.
+        self.reactor.advance(5 * _LOCK_TIMEOUT_MS / 1000)
+        self.pump()
+
+        lock2 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        self.assertIsNone(lock2)
+
+        self.get_success(lock.__aexit__(None, None, None))
+
+    def test_timeout_lock(self) -> None:
+        """Test that we time out locks if they're not updated for ages"""
+
+        lock = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        assert lock is not None
+
+        self.get_success(lock.__aenter__())
+
+        # We simulate the process getting stuck by cancelling the looping call
+        # that keeps the lock active.
+        lock._looping_call.stop()
+
+        # Wait for the lock to timeout.
+        self.reactor.advance(2 * _LOCK_TIMEOUT_MS / 1000)
+
+        lock2 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        self.assertIsNotNone(lock2)
+
+        self.assertFalse(self.get_success(lock.is_still_valid()))
+
+    def test_drop(self) -> None:
+        """Test that dropping the context manager means we stop renewing the lock"""
+
+        lock = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        self.assertIsNotNone(lock)
+
+        del lock
+
+        # Wait for the lock to timeout.
+        self.reactor.advance(2 * _LOCK_TIMEOUT_MS / 1000)
+
+        lock2 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        self.assertIsNotNone(lock2)
+
+    def test_shutdown(self) -> None:
+        """Test that shutting down Synapse releases the locks"""
+        # Acquire two locks
+        lock = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        self.assertIsNotNone(lock)
+        lock2 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key2", write=True)
+        )
+        self.assertIsNotNone(lock2)
+
+        # Now call the shutdown code
+        self.get_success(self.store._on_shutdown())
+
+        self.assertEqual(self.store._live_read_write_lock_tokens, {})