summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/16840.misc1
-rw-r--r--synapse/handlers/worker_lock.py15
-rw-r--r--tests/handlers/test_worker_lock.py23
-rw-r--r--tests/utils.py42
4 files changed, 74 insertions, 7 deletions
diff --git a/changelog.d/16840.misc b/changelog.d/16840.misc
new file mode 100644
index 0000000000..1175e6de71
--- /dev/null
+++ b/changelog.d/16840.misc
@@ -0,0 +1 @@
+Improve lock performance when a lot of locks are all waiting for a single lock to be released.
diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py
index a870fd1124..7e578cf462 100644
--- a/synapse/handlers/worker_lock.py
+++ b/synapse/handlers/worker_lock.py
@@ -182,12 +182,15 @@ class WorkerLocksHandler:
         if not locks:
             return
 
-        def _wake_deferred(deferred: defer.Deferred) -> None:
-            if not deferred.called:
-                deferred.callback(None)
-
-        for lock in locks:
-            self._clock.call_later(0, _wake_deferred, lock.deferred)
+        def _wake_all_locks(
+            locks: Collection[Union[WaitingLock, WaitingMultiLock]]
+        ) -> None:
+            for lock in locks:
+                deferred = lock.deferred
+                if not deferred.called:
+                    deferred.callback(None)
+
+        self._clock.call_later(0, _wake_all_locks, locks)
 
     @wrap_as_background_process("_cleanup_locks")
     async def _cleanup_locks(self) -> None:
diff --git a/tests/handlers/test_worker_lock.py b/tests/handlers/test_worker_lock.py
index 3a4cf82094..6e9a15c8ee 100644
--- a/tests/handlers/test_worker_lock.py
+++ b/tests/handlers/test_worker_lock.py
@@ -27,6 +27,7 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.utils import test_timeout
 
 
 class WorkerLockTestCase(unittest.HomeserverTestCase):
@@ -50,6 +51,28 @@ class WorkerLockTestCase(unittest.HomeserverTestCase):
         self.get_success(d2)
         self.get_success(lock2.__aexit__(None, None, None))
 
+    def test_lock_contention(self) -> None:
+        """Test lock contention when a lot of locks wait on a single worker"""
+
+        # It takes around 0.5s on a 5+ years old laptop
+        with test_timeout(5):
+            nb_locks = 500
+            d = self._take_locks(nb_locks)
+            self.assertEqual(self.get_success(d), nb_locks)
+
+    async def _take_locks(self, nb_locks: int) -> int:
+        locks = [
+            self.hs.get_worker_locks_handler().acquire_lock("test_lock", "")
+            for _ in range(nb_locks)
+        ]
+
+        nb_locks_taken = 0
+        for lock in locks:
+            async with lock:
+                nb_locks_taken += 1
+
+        return nb_locks_taken
+
 
 class WorkerLockWorkersTestCase(BaseMultiWorkerStreamTestCase):
     def prepare(
diff --git a/tests/utils.py b/tests/utils.py
index 757320ebee..9fd26ef348 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -21,7 +21,20 @@
 
 import atexit
 import os
-from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union, overload
+import signal
+from types import FrameType, TracebackType
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    overload,
+)
 
 import attr
 from typing_extensions import Literal, ParamSpec
@@ -379,3 +392,30 @@ def checked_cast(type: Type[T], x: object) -> T:
     """
     assert isinstance(x, type)
     return x
+
+
+class TestTimeout(Exception):
+    pass
+
+
+class test_timeout:
+    def __init__(self, seconds: int, error_message: Optional[str] = None) -> None:
+        if error_message is None:
+            error_message = "test timed out after {}s.".format(seconds)
+        self.seconds = seconds
+        self.error_message = error_message
+
+    def handle_timeout(self, signum: int, frame: Optional[FrameType]) -> None:
+        raise TestTimeout(self.error_message)
+
+    def __enter__(self) -> None:
+        signal.signal(signal.SIGALRM, self.handle_timeout)
+        signal.alarm(self.seconds)
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        signal.alarm(0)