diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index d82822d00d..cfd2882410 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -18,7 +18,7 @@
#
#
import traceback
-from typing import Generator, List, NoReturn, Optional
+from typing import Any, Coroutine, Generator, List, NoReturn, Optional, Tuple, TypeVar
from parameterized import parameterized_class
@@ -39,6 +39,7 @@ from synapse.util.async_helpers import (
ObservableDeferred,
concurrently_execute,
delay_cancellation,
+ gather_optional_coroutines,
stop_cancellation,
timeout_deferred,
)
@@ -46,6 +47,8 @@ from synapse.util.async_helpers import (
from tests.server import get_clock
from tests.unittest import TestCase
+T = TypeVar("T")
+
class ObservableDeferredTest(TestCase):
def test_succeed(self) -> None:
@@ -317,12 +320,19 @@ class ConcurrentlyExecuteTest(TestCase):
await concurrently_execute(callback, [1], 2)
except _TestException as e:
tb = traceback.extract_tb(e.__traceback__)
- # we expect to see "caller", "concurrently_execute", "callback",
- # and some magic from inside ensureDeferred that happens when .fail
- # is called.
+
+ # Remove twisted internals from the stack, as we don't care
+ # about the precise details.
+ tb = traceback.StackSummary(
+ t for t in tb if "/twisted/" not in t.filename
+ )
+
+ # we expect to see "caller", "concurrently_execute" at the top of the stack
self.assertEqual(tb[0].name, "caller")
self.assertEqual(tb[1].name, "concurrently_execute")
- self.assertEqual(tb[-2].name, "callback")
+ # ... some stack frames from the implementation of `concurrently_execute` ...
+ # and at the bottom of the stack we expect to see "callback"
+ self.assertEqual(tb[-1].name, "callback")
else:
self.fail("No exception thrown")
@@ -588,3 +598,106 @@ class AwakenableSleeperTests(TestCase):
sleeper.wake("name")
self.assertTrue(d1.called)
self.assertTrue(d2.called)
+
+
+class GatherCoroutineTests(TestCase):
+ """Tests for `gather_optional_coroutines`"""
+
+ def make_coroutine(self) -> Tuple[Coroutine[Any, Any, T], "defer.Deferred[T]"]:
+ """Returns a coroutine and a deferred that it is waiting on to resolve"""
+
+ d: "defer.Deferred[T]" = defer.Deferred()
+
+ async def inner() -> T:
+ with PreserveLoggingContext():
+ return await d
+
+ return inner(), d
+
+ def test_single(self) -> None:
+ "Test passing in a single coroutine works"
+
+ with LoggingContext("test_ctx") as text_ctx:
+ deferred: "defer.Deferred[None]"
+ coroutine, deferred = self.make_coroutine()
+
+ gather_deferred = defer.ensureDeferred(
+ gather_optional_coroutines(coroutine)
+ )
+
+ # We shouldn't have a result yet, and should be in the sentinel
+ # context.
+ self.assertNoResult(gather_deferred)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+ # Resolving the deferred will resolve the coroutine
+ deferred.callback(None)
+
+ # All coroutines have resolved, and so we should have the results
+ result = self.successResultOf(gather_deferred)
+ self.assertEqual(result, (None,))
+
+ # We should be back in the normal context.
+ self.assertEqual(current_context(), text_ctx)
+
+ def test_multiple_resolve(self) -> None:
+ "Test passing in multiple coroutine that all resolve works"
+
+ with LoggingContext("test_ctx") as test_ctx:
+ deferred1: "defer.Deferred[int]"
+ coroutine1, deferred1 = self.make_coroutine()
+ deferred2: "defer.Deferred[str]"
+ coroutine2, deferred2 = self.make_coroutine()
+
+ gather_deferred = defer.ensureDeferred(
+ gather_optional_coroutines(coroutine1, coroutine2)
+ )
+
+ # We shouldn't have a result yet, and should be in the sentinel
+ # context.
+ self.assertNoResult(gather_deferred)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+ # Even if we resolve one of the coroutines, we shouldn't have a result
+ # yet
+ deferred2.callback("test")
+ self.assertNoResult(gather_deferred)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+ deferred1.callback(1)
+
+ # All coroutines have resolved, and so we should have the results
+ result = self.successResultOf(gather_deferred)
+ self.assertEqual(result, (1, "test"))
+
+ # We should be back in the normal context.
+ self.assertEqual(current_context(), test_ctx)
+
+ def test_multiple_fail(self) -> None:
+ "Test passing in multiple coroutine where one fails does the right thing"
+
+ with LoggingContext("test_ctx") as test_ctx:
+ deferred1: "defer.Deferred[int]"
+ coroutine1, deferred1 = self.make_coroutine()
+ deferred2: "defer.Deferred[str]"
+ coroutine2, deferred2 = self.make_coroutine()
+
+ gather_deferred = defer.ensureDeferred(
+ gather_optional_coroutines(coroutine1, coroutine2)
+ )
+
+ # We shouldn't have a result yet, and should be in the sentinel
+ # context.
+ self.assertNoResult(gather_deferred)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+ # Throw an exception in one of the coroutines
+ exc = Exception("test")
+ deferred2.errback(exc)
+
+ # Expect the gather deferred to immediately fail
+ result_exc = self.failureResultOf(gather_deferred)
+ self.assertEqual(result_exc.value, exc)
+
+ # We should be back in the normal context.
+ self.assertEqual(current_context(), test_ctx)
diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py
index 13a4e6ddaa..c052ba2b75 100644
--- a/tests/util/test_check_dependencies.py
+++ b/tests/util/test_check_dependencies.py
@@ -109,10 +109,13 @@ class TestDependencyChecker(TestCase):
def test_checks_ignore_dev_dependencies(self) -> None:
"""Both generic and per-extra checks should ignore dev dependencies."""
- with patch(
- "synapse.util.check_dependencies.metadata.requires",
- return_value=["dummypkg >= 1; extra == 'mypy'"],
- ), patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}):
+ with (
+ patch(
+ "synapse.util.check_dependencies.metadata.requires",
+ return_value=["dummypkg >= 1; extra == 'mypy'"],
+ ),
+ patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}),
+ ):
# We're testing that none of these calls raise.
with self.mock_installed_package(None):
check_requirements()
@@ -141,10 +144,13 @@ class TestDependencyChecker(TestCase):
def test_check_for_extra_dependencies(self) -> None:
"""Complain if a package required for an extra is missing or old."""
- with patch(
- "synapse.util.check_dependencies.metadata.requires",
- return_value=["dummypkg >= 1; extra == 'cool-extra'"],
- ), patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}):
+ with (
+ patch(
+ "synapse.util.check_dependencies.metadata.requires",
+ return_value=["dummypkg >= 1; extra == 'cool-extra'"],
+ ),
+ patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}),
+ ):
with self.mock_installed_package(None):
self.assertRaises(DependencyException, check_requirements, "cool-extra")
with self.mock_installed_package(old):
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index 7cbb1007da..7510657b85 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -19,9 +19,7 @@
#
#
-from typing import Hashable, Tuple
-
-from typing_extensions import Protocol
+from typing import Hashable, Protocol, Tuple
from twisted.internet import defer, reactor
from twisted.internet.base import ReactorBase
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index af1199ef8a..9254bff79b 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -53,8 +53,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
# return True, whether it's a known entity or not.
self.assertTrue(cache.has_entity_changed("user@foo.com", 0))
self.assertTrue(cache.has_entity_changed("not@here.website", 0))
- self.assertTrue(cache.has_entity_changed("user@foo.com", 3))
- self.assertTrue(cache.has_entity_changed("not@here.website", 3))
+ self.assertTrue(cache.has_entity_changed("user@foo.com", 2))
+ self.assertTrue(cache.has_entity_changed("not@here.website", 2))
def test_entity_has_changed_pops_off_start(self) -> None:
"""
@@ -76,9 +76,11 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
self.assertTrue("user@foo.com" not in cache._entity_to_key)
self.assertEqual(
- cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"]
+ cache.get_all_entities_changed(2).entities,
+ ["bar@baz.net", "user@elsewhere.org"],
)
- self.assertFalse(cache.get_all_entities_changed(2).hit)
+ self.assertFalse(cache.get_all_entities_changed(1).hit)
+ self.assertTrue(cache.get_all_entities_changed(2).hit)
# If we update an existing entity, it keeps the two existing entities
cache.entity_has_changed("bar@baz.net", 5)
@@ -89,7 +91,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
cache.get_all_entities_changed(3).entities,
["user@elsewhere.org", "bar@baz.net"],
)
- self.assertFalse(cache.get_all_entities_changed(2).hit)
+ self.assertFalse(cache.get_all_entities_changed(1).hit)
+ self.assertTrue(cache.get_all_entities_changed(2).hit)
def test_get_all_entities_changed(self) -> None:
"""
@@ -114,7 +117,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
self.assertEqual(
cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"]
)
- self.assertFalse(cache.get_all_entities_changed(1).hit)
+ self.assertFalse(cache.get_all_entities_changed(0).hit)
+ self.assertTrue(cache.get_all_entities_changed(1).hit)
# ... later, things gest more updates
cache.entity_has_changed("user@foo.com", 5)
@@ -149,7 +153,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
# With no entities, it returns True for the past, present, and False for
# the future.
self.assertTrue(cache.has_any_entity_changed(0))
- self.assertTrue(cache.has_any_entity_changed(1))
+ self.assertFalse(cache.has_any_entity_changed(1))
self.assertFalse(cache.has_any_entity_changed(2))
# We add an entity
@@ -251,3 +255,28 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
# Unknown entities will return None
self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), None)
+
+ def test_all_entities_changed(self) -> None:
+ """
+ `StreamChangeCache.all_entities_changed(...)` will mark all entites as changed.
+ """
+ cache = StreamChangeCache("#test", 1, max_size=10)
+
+ cache.entity_has_changed("user@foo.com", 2)
+ cache.entity_has_changed("bar@baz.net", 3)
+ cache.entity_has_changed("user@elsewhere.org", 4)
+
+ cache.all_entities_changed(5)
+
+ # Everything should be marked as changed before the stream position where the
+ # change occurred.
+ self.assertTrue(cache.has_entity_changed("user@foo.com", 4))
+ self.assertTrue(cache.has_entity_changed("bar@baz.net", 4))
+ self.assertTrue(cache.has_entity_changed("user@elsewhere.org", 4))
+
+ # Nothing should be marked as changed at/after the stream position where the
+ # change occurred. In other words, nothing has changed since the stream position
+ # 5.
+ self.assertFalse(cache.has_entity_changed("user@foo.com", 5))
+ self.assertFalse(cache.has_entity_changed("bar@baz.net", 5))
+ self.assertFalse(cache.has_entity_changed("user@elsewhere.org", 5))
diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py
index 646fd2163e..34c2395ecf 100644
--- a/tests/util/test_stringutils.py
+++ b/tests/util/test_stringutils.py
@@ -20,7 +20,11 @@
#
from synapse.api.errors import SynapseError
-from synapse.util.stringutils import assert_valid_client_secret, base62_encode
+from synapse.util.stringutils import (
+ assert_valid_client_secret,
+ base62_encode,
+ is_namedspaced_grammar,
+)
from .. import unittest
@@ -58,3 +62,25 @@ class StringUtilsTestCase(unittest.TestCase):
self.assertEqual("10", base62_encode(62))
self.assertEqual("1c", base62_encode(100))
self.assertEqual("001c", base62_encode(100, minwidth=4))
+
+ def test_namespaced_identifier(self) -> None:
+ self.assertTrue(is_namedspaced_grammar("test"))
+ self.assertTrue(is_namedspaced_grammar("m.test"))
+ self.assertTrue(is_namedspaced_grammar("org.matrix.test"))
+ self.assertTrue(is_namedspaced_grammar("org.matrix.msc1234"))
+ self.assertTrue(is_namedspaced_grammar("test"))
+ self.assertTrue(is_namedspaced_grammar("t-e_s.t"))
+
+ # Must start with letter.
+ self.assertFalse(is_namedspaced_grammar("1test"))
+ self.assertFalse(is_namedspaced_grammar("-test"))
+ self.assertFalse(is_namedspaced_grammar("_test"))
+ self.assertFalse(is_namedspaced_grammar(".test"))
+
+ # Must contain only a-z, 0-9, -, _, ..
+ self.assertFalse(is_namedspaced_grammar("test/"))
+ self.assertFalse(is_namedspaced_grammar('test"'))
+ self.assertFalse(is_namedspaced_grammar("testö"))
+
+ # Must be < 255 characters.
+ self.assertFalse(is_namedspaced_grammar("t" * 256))
diff --git a/tests/util/test_task_scheduler.py b/tests/util/test_task_scheduler.py
index 30f0510c9f..7f6e63bd49 100644
--- a/tests/util/test_task_scheduler.py
+++ b/tests/util/test_task_scheduler.py
@@ -18,8 +18,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
-
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
from twisted.internet.task import deferLater
from twisted.test.proto_helpers import MemoryReactor
@@ -104,38 +103,48 @@ class TestTaskScheduler(HomeserverTestCase):
)
)
- # This is to give the time to the active tasks to finish
+ def get_tasks_of_status(status: TaskStatus) -> List[ScheduledTask]:
+ tasks = (
+ self.get_success(self.task_scheduler.get_task(task_id))
+ for task_id in task_ids
+ )
+ return [t for t in tasks if t is not None and t.status == status]
+
+ # At this point, there should be MAX_CONCURRENT_RUNNING_TASKS active tasks and
+ # one scheduled task.
+ self.assertEqual(
+ len(get_tasks_of_status(TaskStatus.ACTIVE)),
+ TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS,
+ )
+ self.assertEqual(
+ len(get_tasks_of_status(TaskStatus.SCHEDULED)),
+ 1,
+ )
+
+ # Give the time to the active tasks to finish
self.reactor.advance(1)
- # Check that only MAX_CONCURRENT_RUNNING_TASKS tasks has run and that one
+ # Check that MAX_CONCURRENT_RUNNING_TASKS tasks have run and that one
# is still scheduled.
- tasks = [
- self.get_success(self.task_scheduler.get_task(task_id))
- for task_id in task_ids
- ]
-
- self.assertEquals(
- len(
- [t for t in tasks if t is not None and t.status == TaskStatus.COMPLETE]
- ),
+ self.assertEqual(
+ len(get_tasks_of_status(TaskStatus.COMPLETE)),
TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS,
)
+ scheduled_tasks = get_tasks_of_status(TaskStatus.SCHEDULED)
+ self.assertEqual(len(scheduled_tasks), 1)
- scheduled_tasks = [
- t for t in tasks if t is not None and t.status == TaskStatus.ACTIVE
- ]
- self.assertEquals(len(scheduled_tasks), 1)
+ # The scheduled task should start 0.1s after the first of the active tasks
+ # finishes
+ self.reactor.advance(0.1)
+ self.assertEqual(len(get_tasks_of_status(TaskStatus.ACTIVE)), 1)
- # We need to wait for the next run of the scheduler loop
- self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
+ # ... and should finally complete after another second
self.reactor.advance(1)
-
- # Check that the last task has been properly executed after the next scheduler loop run
prev_scheduled_task = self.get_success(
self.task_scheduler.get_task(scheduled_tasks[0].id)
)
assert prev_scheduled_task is not None
- self.assertEquals(
+ self.assertEqual(
prev_scheduled_task.status,
TaskStatus.COMPLETE,
)
diff --git a/tests/util/test_threepids.py b/tests/util/test_threepids.py
deleted file mode 100644
index 15575cc572..0000000000
--- a/tests/util/test_threepids.py
+++ /dev/null
@@ -1,55 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright 2020 Dirk Klimpel
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-
-from synapse.util.threepids import canonicalise_email
-
-from tests.unittest import HomeserverTestCase
-
-
-class CanonicaliseEmailTests(HomeserverTestCase):
- def test_no_at(self) -> None:
- with self.assertRaises(ValueError):
- canonicalise_email("address-without-at.bar")
-
- def test_two_at(self) -> None:
- with self.assertRaises(ValueError):
- canonicalise_email("foo@foo@test.bar")
-
- def test_bad_format(self) -> None:
- with self.assertRaises(ValueError):
- canonicalise_email("user@bad.example.net@good.example.com")
-
- def test_valid_format(self) -> None:
- self.assertEqual(canonicalise_email("foo@test.bar"), "foo@test.bar")
-
- def test_domain_to_lower(self) -> None:
- self.assertEqual(canonicalise_email("foo@TEST.BAR"), "foo@test.bar")
-
- def test_domain_with_umlaut(self) -> None:
- self.assertEqual(canonicalise_email("foo@Öumlaut.com"), "foo@öumlaut.com")
-
- def test_address_casefold(self) -> None:
- self.assertEqual(
- canonicalise_email("Strauß@Example.com"), "strauss@example.com"
- )
-
- def test_address_trim(self) -> None:
- self.assertEqual(canonicalise_email(" foo@test.bar "), "foo@test.bar")
diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py
index 173a7cfaec..6fa575a18e 100644
--- a/tests/util/test_wheel_timer.py
+++ b/tests/util/test_wheel_timer.py
@@ -28,53 +28,55 @@ class WheelTimerTestCase(unittest.TestCase):
def test_single_insert_fetch(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
- obj = object()
- wheel.insert(100, obj, 150)
+ wheel.insert(100, "1", 150)
self.assertListEqual(wheel.fetch(101), [])
self.assertListEqual(wheel.fetch(110), [])
self.assertListEqual(wheel.fetch(120), [])
self.assertListEqual(wheel.fetch(130), [])
self.assertListEqual(wheel.fetch(149), [])
- self.assertListEqual(wheel.fetch(156), [obj])
+ self.assertListEqual(wheel.fetch(156), ["1"])
self.assertListEqual(wheel.fetch(170), [])
def test_multi_insert(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
- obj1 = object()
- obj2 = object()
- obj3 = object()
- wheel.insert(100, obj1, 150)
- wheel.insert(105, obj2, 130)
- wheel.insert(106, obj3, 160)
+ wheel.insert(100, "1", 150)
+ wheel.insert(105, "2", 130)
+ wheel.insert(106, "3", 160)
self.assertListEqual(wheel.fetch(110), [])
- self.assertListEqual(wheel.fetch(135), [obj2])
+ self.assertListEqual(wheel.fetch(135), ["2"])
self.assertListEqual(wheel.fetch(149), [])
- self.assertListEqual(wheel.fetch(158), [obj1])
+ self.assertListEqual(wheel.fetch(158), ["1"])
self.assertListEqual(wheel.fetch(160), [])
- self.assertListEqual(wheel.fetch(200), [obj3])
+ self.assertListEqual(wheel.fetch(200), ["3"])
self.assertListEqual(wheel.fetch(210), [])
def test_insert_past(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
- obj = object()
- wheel.insert(100, obj, 50)
- self.assertListEqual(wheel.fetch(120), [obj])
+ wheel.insert(100, "1", 50)
+ self.assertListEqual(wheel.fetch(120), ["1"])
def test_insert_past_multi(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
- obj1 = object()
- obj2 = object()
- obj3 = object()
- wheel.insert(100, obj1, 150)
- wheel.insert(100, obj2, 140)
- wheel.insert(100, obj3, 50)
- self.assertListEqual(wheel.fetch(110), [obj3])
+ wheel.insert(100, "1", 150)
+ wheel.insert(100, "2", 140)
+ wheel.insert(100, "3", 50)
+ self.assertListEqual(wheel.fetch(110), ["3"])
self.assertListEqual(wheel.fetch(120), [])
- self.assertListEqual(wheel.fetch(147), [obj2])
- self.assertListEqual(wheel.fetch(200), [obj1])
+ self.assertListEqual(wheel.fetch(147), ["2"])
+ self.assertListEqual(wheel.fetch(200), ["1"])
self.assertListEqual(wheel.fetch(240), [])
+
+ def test_multi_insert_then_past(self) -> None:
+ wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
+
+ wheel.insert(100, "1", 150)
+ wheel.insert(100, "2", 160)
+ wheel.insert(100, "3", 155)
+
+ self.assertListEqual(wheel.fetch(110), [])
+ self.assertListEqual(wheel.fetch(158), ["1"])
diff --git a/tests/utils.py b/tests/utils.py
index 9fd26ef348..57986c18bc 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -28,6 +28,7 @@ from typing import (
Callable,
Dict,
List,
+ Literal,
Optional,
Tuple,
Type,
@@ -37,7 +38,7 @@ from typing import (
)
import attr
-from typing_extensions import Literal, ParamSpec
+from typing_extensions import ParamSpec
from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
@@ -181,7 +182,6 @@ def default_config(
"max_mau_value": 50,
"mau_trial_days": 0,
"mau_stats_only": False,
- "mau_limits_reserved_threepids": [],
"admin_contact": None,
"rc_message": {"per_second": 10000, "burst_count": 10000},
"rc_registration": {"per_second": 10000, "burst_count": 10000},
@@ -200,9 +200,8 @@ def default_config(
"per_user": {"per_second": 10000, "burst_count": 10000},
},
"rc_3pid_validation": {"per_second": 10000, "burst_count": 10000},
- "saml2_enabled": False,
+ "rc_presence": {"per_user": {"per_second": 10000, "burst_count": 10000}},
"public_baseurl": None,
- "default_identity_server": None,
"key_refresh_interval": 24 * 60 * 60 * 1000,
"old_signing_keys": {},
"tls_fingerprints": [],
@@ -399,11 +398,24 @@ class TestTimeout(Exception):
class test_timeout:
+ """
+ FIXME: This implementation is not robust against other code tight-looping and
+ preventing the signals propagating and timing out the test. You may need to add
+ `time.sleep(0.1)` to your code in order to allow this timeout to work correctly.
+
+ ```py
+ with test_timeout(3):
+ while True:
+ my_checking_func()
+ time.sleep(0.1)
+ ```
+ """
+
def __init__(self, seconds: int, error_message: Optional[str] = None) -> None:
- if error_message is None:
- error_message = "test timed out after {}s.".format(seconds)
+ self.error_message = f"Test timed out after {seconds}s"
+ if error_message is not None:
+ self.error_message += f": {error_message}"
self.seconds = seconds
- self.error_message = error_message
def handle_timeout(self, signum: int, frame: Optional[FrameType]) -> None:
raise TestTimeout(self.error_message)
|