summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_worker_lock.py23
-rw-r--r--tests/replication/_base.py6
-rw-r--r--tests/rest/client/test_filter.py2
-rw-r--r--tests/rest/client/test_rooms.py12
-rw-r--r--tests/rest/client/utils.py6
-rw-r--r--tests/storage/test_cleanup_extrems.py18
-rw-r--r--tests/storage/test_room_search.py16
-rw-r--r--tests/unittest.py3
-rw-r--r--tests/util/test_linearizer.py3
-rw-r--r--tests/utils.py48
10 files changed, 99 insertions, 38 deletions
diff --git a/tests/handlers/test_worker_lock.py b/tests/handlers/test_worker_lock.py

index 3a4cf82094..6e9a15c8ee 100644 --- a/tests/handlers/test_worker_lock.py +++ b/tests/handlers/test_worker_lock.py
@@ -27,6 +27,7 @@ from synapse.util import Clock from tests import unittest from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.utils import test_timeout class WorkerLockTestCase(unittest.HomeserverTestCase): @@ -50,6 +51,28 @@ class WorkerLockTestCase(unittest.HomeserverTestCase): self.get_success(d2) self.get_success(lock2.__aexit__(None, None, None)) + def test_lock_contention(self) -> None: + """Test lock contention when a lot of locks wait on a single worker""" + + # It takes around 0.5s on a 5+ years old laptop + with test_timeout(5): + nb_locks = 500 + d = self._take_locks(nb_locks) + self.assertEqual(self.get_success(d), nb_locks) + + async def _take_locks(self, nb_locks: int) -> int: + locks = [ + self.hs.get_worker_locks_handler().acquire_lock("test_lock", "") + for _ in range(nb_locks) + ] + + nb_locks_taken = 0 + for lock in locks: + async with lock: + nb_locks_taken += 1 + + return nb_locks_taken + class WorkerLockWorkersTestCase(BaseMultiWorkerStreamTestCase): def prepare( diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index d2220f8195..8437da1cdd 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py
@@ -495,9 +495,9 @@ class FakeRedisPubSubServer: """A fake Redis server for pub/sub.""" def __init__(self) -> None: - self._subscribers_by_channel: Dict[ - bytes, Set["FakeRedisPubSubProtocol"] - ] = defaultdict(set) + self._subscribers_by_channel: Dict[bytes, Set["FakeRedisPubSubProtocol"]] = ( + defaultdict(set) + ) def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None: """A connection has called SUBSCRIBE""" diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index 0a894ad081..9cfc6b224f 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py
@@ -72,7 +72,7 @@ class FilterTestCase(unittest.HomeserverTestCase): def test_add_filter_non_local_user(self) -> None: _is_mine = self.hs.is_mine - self.hs.is_mine = lambda target_user: False # type: ignore[method-assign] + self.hs.is_mine = lambda target_user: False # type: ignore[assignment] channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % (self.user_id), diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index b11a73e92b..d2f2ded487 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py
@@ -1222,9 +1222,9 @@ class RoomJoinTestCase(RoomBase): """ # Register a dummy callback. Make it allow all room joins for now. - return_value: Union[ - Literal["NOT_SPAM"], Tuple[Codes, dict], Codes - ] = synapse.module_api.NOT_SPAM + return_value: Union[Literal["NOT_SPAM"], Tuple[Codes, dict], Codes] = ( + synapse.module_api.NOT_SPAM + ) async def user_may_join_room( userid: str, @@ -1664,9 +1664,9 @@ class RoomMessagesTestCase(RoomBase): expected_fields: dict, ) -> None: class SpamCheck: - mock_return_value: Union[ - str, bool, Codes, Tuple[Codes, JsonDict], bool - ] = "NOT_SPAM" + mock_return_value: Union[str, bool, Codes, Tuple[Codes, JsonDict], bool] = ( + "NOT_SPAM" + ) mock_content: Optional[JsonDict] = None async def check_event_for_spam( diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 10cfe22d8e..daa68d78b9 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py
@@ -87,8 +87,7 @@ class RestHelper: expect_code: Literal[200] = ..., extra_content: Optional[Dict] = ..., custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ..., - ) -> str: - ... + ) -> str: ... @overload def create_room_as( @@ -100,8 +99,7 @@ class RestHelper: expect_code: int = ..., extra_content: Optional[Dict] = ..., custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ..., - ) -> Optional[str]: - ... + ) -> Optional[str]: ... def create_room_as( self, diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 249c6b39f7..d5b9996284 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py
@@ -337,15 +337,15 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): """Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion() expires old entries correctly. """ - self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ - "1" - ] = 100000 - self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ - "2" - ] = 200000 - self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ - "3" - ] = 300000 + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion["1"] = ( + 100000 + ) + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion["2"] = ( + 200000 + ) + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion["3"] = ( + 300000 + ) self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion() # All entries within time frame diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index 01c5324802..1eab89f140 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py
@@ -328,9 +328,11 @@ class MessageSearchTest(HomeserverTestCase): self.assertEqual( result["count"], 1 if expect_to_contain else 0, - f"expected '{query}' to match '{self.PHRASE}'" - if expect_to_contain - else f"'{query}' unexpectedly matched '{self.PHRASE}'", + ( + f"expected '{query}' to match '{self.PHRASE}'" + if expect_to_contain + else f"'{query}' unexpectedly matched '{self.PHRASE}'" + ), ) self.assertEqual( len(result["results"]), @@ -346,9 +348,11 @@ class MessageSearchTest(HomeserverTestCase): self.assertEqual( result["count"], 1 if expect_to_contain else 0, - f"expected '{query}' to match '{self.PHRASE}'" - if expect_to_contain - else f"'{query}' unexpectedly matched '{self.PHRASE}'", + ( + f"expected '{query}' to match '{self.PHRASE}'" + if expect_to_contain + else f"'{query}' unexpectedly matched '{self.PHRASE}'" + ), ) self.assertEqual( len(result["results"]), diff --git a/tests/unittest.py b/tests/unittest.py
index 33c9a384ea..6fe0cd4a2d 100644 --- a/tests/unittest.py +++ b/tests/unittest.py
@@ -109,8 +109,7 @@ class _TypedFailure(Generic[_ExcType], Protocol): """Extension to twisted.Failure, where the 'value' has a certain type.""" @property - def value(self) -> _ExcType: - ... + def value(self) -> _ExcType: ... def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]: diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index d4268bc2e2..7cbb1007da 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py
@@ -34,8 +34,7 @@ from tests import unittest class UnblockFunction(Protocol): - def __call__(self, pump_reactor: bool = True) -> None: - ... + def __call__(self, pump_reactor: bool = True) -> None: ... class LinearizerTestCase(unittest.TestCase): diff --git a/tests/utils.py b/tests/utils.py
index b5dbd60a9c..9fd26ef348 100644 --- a/tests/utils.py +++ b/tests/utils.py
@@ -21,7 +21,20 @@ import atexit import os -from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union, overload +import signal +from types import FrameType, TracebackType +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + overload, +) import attr from typing_extensions import Literal, ParamSpec @@ -121,13 +134,11 @@ def setupdb() -> None: @overload -def default_config(name: str, parse: Literal[False] = ...) -> Dict[str, object]: - ... +def default_config(name: str, parse: Literal[False] = ...) -> Dict[str, object]: ... @overload -def default_config(name: str, parse: Literal[True]) -> HomeServerConfig: - ... +def default_config(name: str, parse: Literal[True]) -> HomeServerConfig: ... def default_config( @@ -381,3 +392,30 @@ def checked_cast(type: Type[T], x: object) -> T: """ assert isinstance(x, type) return x + + +class TestTimeout(Exception): + pass + + +class test_timeout: + def __init__(self, seconds: int, error_message: Optional[str] = None) -> None: + if error_message is None: + error_message = "test timed out after {}s.".format(seconds) + self.seconds = seconds + self.error_message = error_message + + def handle_timeout(self, signum: int, frame: Optional[FrameType]) -> None: + raise TestTimeout(self.error_message) + + def __enter__(self) -> None: + signal.signal(signal.SIGALRM, self.handle_timeout) + signal.alarm(self.seconds) + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + signal.alarm(0)