diff --git a/tests/handlers/test_worker_lock.py b/tests/handlers/test_worker_lock.py
index 3a4cf82094..6e9a15c8ee 100644
--- a/tests/handlers/test_worker_lock.py
+++ b/tests/handlers/test_worker_lock.py
@@ -27,6 +27,7 @@ from synapse.util import Clock
from tests import unittest
from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.utils import test_timeout
class WorkerLockTestCase(unittest.HomeserverTestCase):
@@ -50,6 +51,28 @@ class WorkerLockTestCase(unittest.HomeserverTestCase):
self.get_success(d2)
self.get_success(lock2.__aexit__(None, None, None))
+ def test_lock_contention(self) -> None:
+ """Test lock contention when a lot of locks wait on a single worker"""
+
+ # It takes around 0.5s on a 5+ years old laptop
+ with test_timeout(5):
+ nb_locks = 500
+ d = self._take_locks(nb_locks)
+ self.assertEqual(self.get_success(d), nb_locks)
+
+ async def _take_locks(self, nb_locks: int) -> int:
+ locks = [
+ self.hs.get_worker_locks_handler().acquire_lock("test_lock", "")
+ for _ in range(nb_locks)
+ ]
+
+ nb_locks_taken = 0
+ for lock in locks:
+ async with lock:
+ nb_locks_taken += 1
+
+ return nb_locks_taken
+
class WorkerLockWorkersTestCase(BaseMultiWorkerStreamTestCase):
def prepare(
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index d2220f8195..8437da1cdd 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -495,9 +495,9 @@ class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub."""
def __init__(self) -> None:
- self._subscribers_by_channel: Dict[
- bytes, Set["FakeRedisPubSubProtocol"]
- ] = defaultdict(set)
+ self._subscribers_by_channel: Dict[bytes, Set["FakeRedisPubSubProtocol"]] = (
+ defaultdict(set)
+ )
def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None:
"""A connection has called SUBSCRIBE"""
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index 0a894ad081..9cfc6b224f 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -72,7 +72,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
def test_add_filter_non_local_user(self) -> None:
_is_mine = self.hs.is_mine
- self.hs.is_mine = lambda target_user: False # type: ignore[method-assign]
+ self.hs.is_mine = lambda target_user: False # type: ignore[assignment]
channel = self.make_request(
"POST",
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index b11a73e92b..d2f2ded487 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -1222,9 +1222,9 @@ class RoomJoinTestCase(RoomBase):
"""
# Register a dummy callback. Make it allow all room joins for now.
- return_value: Union[
- Literal["NOT_SPAM"], Tuple[Codes, dict], Codes
- ] = synapse.module_api.NOT_SPAM
+ return_value: Union[Literal["NOT_SPAM"], Tuple[Codes, dict], Codes] = (
+ synapse.module_api.NOT_SPAM
+ )
async def user_may_join_room(
userid: str,
@@ -1664,9 +1664,9 @@ class RoomMessagesTestCase(RoomBase):
expected_fields: dict,
) -> None:
class SpamCheck:
- mock_return_value: Union[
- str, bool, Codes, Tuple[Codes, JsonDict], bool
- ] = "NOT_SPAM"
+ mock_return_value: Union[str, bool, Codes, Tuple[Codes, JsonDict], bool] = (
+ "NOT_SPAM"
+ )
mock_content: Optional[JsonDict] = None
async def check_event_for_spam(
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 10cfe22d8e..daa68d78b9 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -87,8 +87,7 @@ class RestHelper:
expect_code: Literal[200] = ...,
extra_content: Optional[Dict] = ...,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
- ) -> str:
- ...
+ ) -> str: ...
@overload
def create_room_as(
@@ -100,8 +99,7 @@ class RestHelper:
expect_code: int = ...,
extra_content: Optional[Dict] = ...,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
- ) -> Optional[str]:
- ...
+ ) -> Optional[str]: ...
def create_room_as(
self,
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 249c6b39f7..d5b9996284 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -337,15 +337,15 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
"""Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion()
expires old entries correctly.
"""
- self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
- "1"
- ] = 100000
- self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
- "2"
- ] = 200000
- self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
- "3"
- ] = 300000
+ self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion["1"] = (
+ 100000
+ )
+ self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion["2"] = (
+ 200000
+ )
+ self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion["3"] = (
+ 300000
+ )
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
# All entries within time frame
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index 01c5324802..1eab89f140 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -328,9 +328,11 @@ class MessageSearchTest(HomeserverTestCase):
self.assertEqual(
result["count"],
1 if expect_to_contain else 0,
- f"expected '{query}' to match '{self.PHRASE}'"
- if expect_to_contain
- else f"'{query}' unexpectedly matched '{self.PHRASE}'",
+ (
+ f"expected '{query}' to match '{self.PHRASE}'"
+ if expect_to_contain
+ else f"'{query}' unexpectedly matched '{self.PHRASE}'"
+ ),
)
self.assertEqual(
len(result["results"]),
@@ -346,9 +348,11 @@ class MessageSearchTest(HomeserverTestCase):
self.assertEqual(
result["count"],
1 if expect_to_contain else 0,
- f"expected '{query}' to match '{self.PHRASE}'"
- if expect_to_contain
- else f"'{query}' unexpectedly matched '{self.PHRASE}'",
+ (
+ f"expected '{query}' to match '{self.PHRASE}'"
+ if expect_to_contain
+ else f"'{query}' unexpectedly matched '{self.PHRASE}'"
+ ),
)
self.assertEqual(
len(result["results"]),
diff --git a/tests/unittest.py b/tests/unittest.py
index 33c9a384ea..6fe0cd4a2d 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -109,8 +109,7 @@ class _TypedFailure(Generic[_ExcType], Protocol):
"""Extension to twisted.Failure, where the 'value' has a certain type."""
@property
- def value(self) -> _ExcType:
- ...
+ def value(self) -> _ExcType: ...
def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]:
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index d4268bc2e2..7cbb1007da 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -34,8 +34,7 @@ from tests import unittest
class UnblockFunction(Protocol):
- def __call__(self, pump_reactor: bool = True) -> None:
- ...
+ def __call__(self, pump_reactor: bool = True) -> None: ...
class LinearizerTestCase(unittest.TestCase):
diff --git a/tests/utils.py b/tests/utils.py
index b5dbd60a9c..9fd26ef348 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -21,7 +21,20 @@
import atexit
import os
-from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union, overload
+import signal
+from types import FrameType, TracebackType
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ overload,
+)
import attr
from typing_extensions import Literal, ParamSpec
@@ -121,13 +134,11 @@ def setupdb() -> None:
@overload
-def default_config(name: str, parse: Literal[False] = ...) -> Dict[str, object]:
- ...
+def default_config(name: str, parse: Literal[False] = ...) -> Dict[str, object]: ...
@overload
-def default_config(name: str, parse: Literal[True]) -> HomeServerConfig:
- ...
+def default_config(name: str, parse: Literal[True]) -> HomeServerConfig: ...
def default_config(
@@ -381,3 +392,30 @@ def checked_cast(type: Type[T], x: object) -> T:
"""
assert isinstance(x, type)
return x
+
+
+class TestTimeout(Exception):
+ pass
+
+
+class test_timeout:
+ def __init__(self, seconds: int, error_message: Optional[str] = None) -> None:
+ if error_message is None:
+ error_message = "test timed out after {}s.".format(seconds)
+ self.seconds = seconds
+ self.error_message = error_message
+
+ def handle_timeout(self, signum: int, frame: Optional[FrameType]) -> None:
+ raise TestTimeout(self.error_message)
+
+ def __enter__(self) -> None:
+ signal.signal(signal.SIGALRM, self.handle_timeout)
+ signal.alarm(self.seconds)
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ signal.alarm(0)
|