diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/handlers/test_worker_lock.py | 23 | ||||
-rw-r--r-- | tests/utils.py | 42 |
2 files changed, 64 insertions, 1 deletions
diff --git a/tests/handlers/test_worker_lock.py b/tests/handlers/test_worker_lock.py index 3a4cf82094..6e9a15c8ee 100644 --- a/tests/handlers/test_worker_lock.py +++ b/tests/handlers/test_worker_lock.py @@ -27,6 +27,7 @@ from synapse.util import Clock from tests import unittest from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.utils import test_timeout class WorkerLockTestCase(unittest.HomeserverTestCase): @@ -50,6 +51,28 @@ class WorkerLockTestCase(unittest.HomeserverTestCase): self.get_success(d2) self.get_success(lock2.__aexit__(None, None, None)) + def test_lock_contention(self) -> None: + """Test lock contention when a lot of locks wait on a single worker""" + + # It takes around 0.5s on a 5+ years old laptop + with test_timeout(5): + nb_locks = 500 + d = self._take_locks(nb_locks) + self.assertEqual(self.get_success(d), nb_locks) + + async def _take_locks(self, nb_locks: int) -> int: + locks = [ + self.hs.get_worker_locks_handler().acquire_lock("test_lock", "") + for _ in range(nb_locks) + ] + + nb_locks_taken = 0 + for lock in locks: + async with lock: + nb_locks_taken += 1 + + return nb_locks_taken + class WorkerLockWorkersTestCase(BaseMultiWorkerStreamTestCase): def prepare( diff --git a/tests/utils.py b/tests/utils.py index 757320ebee..9fd26ef348 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -21,7 +21,20 @@ import atexit import os -from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union, overload +import signal +from types import FrameType, TracebackType +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + overload, +) import attr from typing_extensions import Literal, ParamSpec @@ -379,3 +392,30 @@ def checked_cast(type: Type[T], x: object) -> T: """ assert isinstance(x, type) return x + + +class TestTimeout(Exception): + pass + + +class test_timeout: + def __init__(self, seconds: int, error_message: Optional[str] = None) -> None: + if error_message is None: + error_message = "test timed out after {}s.".format(seconds) + self.seconds = seconds + self.error_message = error_message + + def handle_timeout(self, signum: int, frame: Optional[FrameType]) -> None: + raise TestTimeout(self.error_message) + + def __enter__(self) -> None: + signal.signal(signal.SIGALRM, self.handle_timeout) + signal.alarm(self.seconds) + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + signal.alarm(0) |