diff --git a/tests/util/caches/test_cached_call.py b/tests/util/caches/test_cached_call.py
index 80b97167ba..9266f12590 100644
--- a/tests/util/caches/test_cached_call.py
+++ b/tests/util/caches/test_cached_call.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import NoReturn
from unittest.mock import Mock
from twisted.internet import defer
@@ -23,14 +24,14 @@ from tests.unittest import TestCase
class CachedCallTestCase(TestCase):
- def test_get(self):
+ def test_get(self) -> None:
"""
Happy-path test case: makes a couple of calls and makes sure they behave
correctly
"""
- d = Deferred()
+ d: "Deferred[int]" = Deferred()
- async def f():
+ async def f() -> int:
return await d
slow_call = Mock(side_effect=f)
@@ -43,7 +44,7 @@ class CachedCallTestCase(TestCase):
# now fire off a couple of calls
completed_results = []
- async def r():
+ async def r() -> None:
res = await cached_call.get()
completed_results.append(res)
@@ -69,12 +70,12 @@ class CachedCallTestCase(TestCase):
self.assertEqual(r3, 123)
slow_call.assert_not_called()
- def test_fast_call(self):
+ def test_fast_call(self) -> None:
"""
Test the behaviour when the underlying function completes immediately
"""
- async def f():
+ async def f() -> int:
return 12
fast_call = Mock(side_effect=f)
@@ -92,12 +93,12 @@ class CachedCallTestCase(TestCase):
class RetryOnExceptionCachedCallTestCase(TestCase):
- def test_get(self):
+ def test_get(self) -> None:
# set up the RetryOnExceptionCachedCall around a function which will fail
# (after a while)
- d = Deferred()
+ d: "Deferred[int]" = Deferred()
- async def f1():
+ async def f1() -> NoReturn:
await d
raise ValueError("moo")
@@ -110,7 +111,7 @@ class RetryOnExceptionCachedCallTestCase(TestCase):
# now fire off a couple of calls
completed_results = []
- async def r():
+ async def r() -> None:
try:
await cached_call.get()
except Exception as e1:
@@ -137,7 +138,7 @@ class RetryOnExceptionCachedCallTestCase(TestCase):
# to the getter
d = Deferred()
- async def f2():
+ async def f2() -> int:
return await d
slow_call.reset_mock()
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index 02b99b466a..f74d82b1dc 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -13,6 +13,7 @@
# limitations under the License.
from functools import partial
+from typing import List, Tuple
from twisted.internet import defer
@@ -22,20 +23,20 @@ from tests.unittest import TestCase
class DeferredCacheTestCase(TestCase):
- def test_empty(self):
- cache = DeferredCache("test")
+ def test_empty(self) -> None:
+ cache: DeferredCache[str, int] = DeferredCache("test")
with self.assertRaises(KeyError):
cache.get("foo")
- def test_hit(self):
- cache = DeferredCache("test")
+ def test_hit(self) -> None:
+ cache: DeferredCache[str, int] = DeferredCache("test")
cache.prefill("foo", 123)
self.assertEqual(self.successResultOf(cache.get("foo")), 123)
- def test_hit_deferred(self):
- cache = DeferredCache("test")
- origin_d = defer.Deferred()
+ def test_hit_deferred(self) -> None:
+ cache: DeferredCache[str, int] = DeferredCache("test")
+ origin_d: "defer.Deferred[int]" = defer.Deferred()
set_d = cache.set("k1", origin_d)
# get should return an incomplete deferred
@@ -43,7 +44,7 @@ class DeferredCacheTestCase(TestCase):
self.assertFalse(get_d.called)
# add a callback that will make sure that the set_d gets called before the get_d
- def check1(r):
+ def check1(r: str) -> str:
self.assertTrue(set_d.called)
return r
@@ -55,16 +56,16 @@ class DeferredCacheTestCase(TestCase):
self.assertEqual(self.successResultOf(set_d), 99)
self.assertEqual(self.successResultOf(get_d), 99)
- def test_callbacks(self):
+ def test_callbacks(self) -> None:
"""Invalidation callbacks are called at the right time"""
- cache = DeferredCache("test")
+ cache: DeferredCache[str, int] = DeferredCache("test")
callbacks = set()
# start with an entry, with a callback
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
# now replace that entry with a pending result
- origin_d = defer.Deferred()
+ origin_d: "defer.Deferred[int]" = defer.Deferred()
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
# ... and also make a get request
@@ -89,15 +90,15 @@ class DeferredCacheTestCase(TestCase):
cache.prefill("k1", 30)
self.assertEqual(callbacks, {"set", "get"})
- def test_set_fail(self):
- cache = DeferredCache("test")
+ def test_set_fail(self) -> None:
+ cache: DeferredCache[str, int] = DeferredCache("test")
callbacks = set()
# start with an entry, with a callback
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
# now replace that entry with a pending result
- origin_d = defer.Deferred()
+ origin_d: defer.Deferred = defer.Deferred()
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
# ... and also make a get request
@@ -126,9 +127,9 @@ class DeferredCacheTestCase(TestCase):
cache.prefill("k1", 30)
self.assertEqual(callbacks, {"prefill", "get2"})
- def test_get_immediate(self):
- cache = DeferredCache("test")
- d1 = defer.Deferred()
+ def test_get_immediate(self) -> None:
+ cache: DeferredCache[str, int] = DeferredCache("test")
+ d1: "defer.Deferred[int]" = defer.Deferred()
cache.set("key1", d1)
# get_immediate should return default
@@ -142,27 +143,27 @@ class DeferredCacheTestCase(TestCase):
v = cache.get_immediate("key1", 1)
self.assertEqual(v, 2)
- def test_invalidate(self):
- cache = DeferredCache("test")
+ def test_invalidate(self) -> None:
+ cache: DeferredCache[Tuple[str], int] = DeferredCache("test")
cache.prefill(("foo",), 123)
cache.invalidate(("foo",))
with self.assertRaises(KeyError):
cache.get(("foo",))
- def test_invalidate_all(self):
- cache = DeferredCache("testcache")
+ def test_invalidate_all(self) -> None:
+ cache: DeferredCache[str, str] = DeferredCache("testcache")
callback_record = [False, False]
- def record_callback(idx):
+ def record_callback(idx: int) -> None:
callback_record[idx] = True
# add a couple of pending entries
- d1 = defer.Deferred()
+ d1: "defer.Deferred[str]" = defer.Deferred()
cache.set("key1", d1, partial(record_callback, 0))
- d2 = defer.Deferred()
+ d2: "defer.Deferred[str]" = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))
# lookup should return pending deferreds
@@ -193,8 +194,8 @@ class DeferredCacheTestCase(TestCase):
with self.assertRaises(KeyError):
cache.get("key1", None)
- def test_eviction(self):
- cache = DeferredCache(
+ def test_eviction(self) -> None:
+ cache: DeferredCache[int, str] = DeferredCache(
"test", max_entries=2, apply_cache_factor_from_config=False
)
@@ -208,8 +209,8 @@ class DeferredCacheTestCase(TestCase):
cache.get(2)
cache.get(3)
- def test_eviction_lru(self):
- cache = DeferredCache(
+ def test_eviction_lru(self) -> None:
+ cache: DeferredCache[int, str] = DeferredCache(
"test", max_entries=2, apply_cache_factor_from_config=False
)
@@ -227,8 +228,8 @@ class DeferredCacheTestCase(TestCase):
cache.get(1)
cache.get(3)
- def test_eviction_iterable(self):
- cache = DeferredCache(
+ def test_eviction_iterable(self) -> None:
+ cache: DeferredCache[int, List[str]] = DeferredCache(
"test",
max_entries=3,
apply_cache_factor_from_config=False,
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 48e616ac74..13f1edd533 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Set
+from typing import Iterable, Set, Tuple, cast
from unittest import mock
from twisted.internet import defer, reactor
from twisted.internet.defer import CancelledError, Deferred
+from twisted.internet.interfaces import IReactorTime
from synapse.api.errors import SynapseError
from synapse.logging.context import (
@@ -28,7 +29,7 @@ from synapse.logging.context import (
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
-from synapse.util.caches.descriptors import cached, cachedList, lru_cache
+from synapse.util.caches.descriptors import cached, cachedList
from tests import unittest
from tests.test_utils import get_awaitable_result
@@ -36,41 +37,9 @@ from tests.test_utils import get_awaitable_result
logger = logging.getLogger(__name__)
-class LruCacheDecoratorTestCase(unittest.TestCase):
- def test_base(self):
- class Cls:
- def __init__(self):
- self.mock = mock.Mock()
-
- @lru_cache()
- def fn(self, arg1, arg2):
- return self.mock(arg1, arg2)
-
- obj = Cls()
- obj.mock.return_value = "fish"
- r = obj.fn(1, 2)
- self.assertEqual(r, "fish")
- obj.mock.assert_called_once_with(1, 2)
- obj.mock.reset_mock()
-
- # a call with different params should call the mock again
- obj.mock.return_value = "chips"
- r = obj.fn(1, 3)
- self.assertEqual(r, "chips")
- obj.mock.assert_called_once_with(1, 3)
- obj.mock.reset_mock()
-
- # the two values should now be cached
- r = obj.fn(1, 2)
- self.assertEqual(r, "fish")
- r = obj.fn(1, 3)
- self.assertEqual(r, "chips")
- obj.mock.assert_not_called()
-
-
def run_on_reactor():
- d = defer.Deferred()
- reactor.callLater(0, d.callback, 0)
+ d: "Deferred[int]" = defer.Deferred()
+ cast(IReactorTime, reactor).callLater(0, d.callback, 0)
return make_deferred_yieldable(d)
@@ -256,7 +225,8 @@ class DescriptorTestCase(unittest.TestCase):
callbacks: Set[str] = set()
# set off an asynchronous request
- obj.result = origin_d = defer.Deferred()
+ origin_d: Deferred = defer.Deferred()
+ obj.result = origin_d
d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
self.assertFalse(d1.called)
@@ -294,7 +264,7 @@ class DescriptorTestCase(unittest.TestCase):
"""Check that logcontexts are set and restored correctly when
using the cache."""
- complete_lookup = defer.Deferred()
+ complete_lookup: Deferred = defer.Deferred()
class Cls:
@descriptors.cached()
@@ -478,10 +448,10 @@ class DescriptorTestCase(unittest.TestCase):
@cached(cache_context=True)
async def func2(self, key, cache_context):
- return self.func3(key, on_invalidate=cache_context.invalidate)
+ return await self.func3(key, on_invalidate=cache_context.invalidate)
- @lru_cache(cache_context=True)
- def func3(self, key, cache_context):
+ @cached(cache_context=True)
+ async def func3(self, key, cache_context):
self.invalidate = cache_context.invalidate
return 42
@@ -804,10 +774,14 @@ class CachedListDescriptorTestCase(unittest.TestCase):
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1, arg2):
- assert current_context().name == "c1"
+ context = current_context()
+ assert isinstance(context, LoggingContext)
+ assert context.name == "c1"
# we want this to behave like an asynchronous function
await run_on_reactor()
- assert current_context().name == "c1"
+ context = current_context()
+ assert isinstance(context, LoggingContext)
+ assert context.name == "c1"
return self.mock(args1, arg2)
with LoggingContext("c1") as c1:
@@ -866,7 +840,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
return self.mock(args1)
obj = Cls()
- deferred_result = Deferred()
+ deferred_result: "Deferred[dict]" = Deferred()
obj.mock.return_value = deferred_result
# start off several concurrent lookups of the same key
@@ -1008,3 +982,34 @@ class CachedListDescriptorTestCase(unittest.TestCase):
obj.inner_context_was_finished, "Tried to restart a finished logcontext"
)
self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+ def test_num_args_mismatch(self):
+ """
+ Make sure someone does not accidentally use @cachedList on a method with
+ a mismatch in the number args to the underlying single cache method.
+ """
+
+ class Cls:
+ @descriptors.cached(tree=True)
+ def fn(self, room_id, event_id):
+ pass
+
+ # This is wrong ❌. `@cachedList` expects to be given the same number
+ # of arguments as the underlying cached function, just with one of
+ # the arguments being an iterable
+ @descriptors.cachedList(cached_method_name="fn", list_name="keys")
+ def list_fn(self, keys: Iterable[Tuple[str, str]]):
+ pass
+
+ # Corrected syntax ✅
+ #
+ # @cachedList(cached_method_name="fn", list_name="event_ids")
+ # async def list_fn(
+ # self, room_id: str, event_ids: Collection[str],
+ # )
+
+ obj = Cls()
+
+ # Make sure this raises an error about the arg mismatch
+ with self.assertRaises(TypeError):
+ obj.list_fn([("foo", "bar")])
diff --git a/tests/util/caches/test_response_cache.py b/tests/util/caches/test_response_cache.py
index 025b73e32f..f09eeecada 100644
--- a/tests/util/caches/test_response_cache.py
+++ b/tests/util/caches/test_response_cache.py
@@ -35,7 +35,7 @@ class ResponseCacheTestCase(TestCase):
(These have cache with a short timeout_ms=, shorter than will be tested through advancing the clock)
"""
- def setUp(self):
+ def setUp(self) -> None:
self.reactor, self.clock = get_clock()
def with_cache(self, name: str, ms: int = 0) -> ResponseCache:
@@ -49,7 +49,7 @@ class ResponseCacheTestCase(TestCase):
await self.clock.sleep(1)
return o
- def test_cache_hit(self):
+ def test_cache_hit(self) -> None:
cache = self.with_cache("keeping_cache", ms=9001)
expected_result = "howdy"
@@ -74,7 +74,7 @@ class ResponseCacheTestCase(TestCase):
"cache should still have the result",
)
- def test_cache_miss(self):
+ def test_cache_miss(self) -> None:
cache = self.with_cache("trashing_cache", ms=0)
expected_result = "howdy"
@@ -90,7 +90,7 @@ class ResponseCacheTestCase(TestCase):
)
self.assertCountEqual([], cache.keys(), "cache should not have the result now")
- def test_cache_expire(self):
+ def test_cache_expire(self) -> None:
cache = self.with_cache("short_cache", ms=1000)
expected_result = "howdy"
@@ -115,7 +115,7 @@ class ResponseCacheTestCase(TestCase):
self.reactor.pump((2,))
self.assertCountEqual([], cache.keys(), "cache should not have the result now")
- def test_cache_wait_hit(self):
+ def test_cache_wait_hit(self) -> None:
cache = self.with_cache("neutral_cache")
expected_result = "howdy"
@@ -131,7 +131,7 @@ class ResponseCacheTestCase(TestCase):
self.assertEqual(expected_result, self.successResultOf(wrap_d))
- def test_cache_wait_expire(self):
+ def test_cache_wait_expire(self) -> None:
cache = self.with_cache("medium_cache", ms=3000)
expected_result = "howdy"
@@ -162,7 +162,7 @@ class ResponseCacheTestCase(TestCase):
self.assertCountEqual([], cache.keys(), "cache should not have the result now")
@parameterized.expand([(True,), (False,)])
- def test_cache_context_nocache(self, should_cache: bool):
+ def test_cache_context_nocache(self, should_cache: bool) -> None:
"""If the callback clears the should_cache bit, the result should not be cached"""
cache = self.with_cache("medium_cache", ms=3000)
@@ -170,7 +170,7 @@ class ResponseCacheTestCase(TestCase):
call_count = 0
- async def non_caching(o: str, cache_context: ResponseCacheContext[int]):
+ async def non_caching(o: str, cache_context: ResponseCacheContext[int]) -> str:
nonlocal call_count
call_count += 1
await self.clock.sleep(1)
diff --git a/tests/util/caches/test_ttlcache.py b/tests/util/caches/test_ttlcache.py
index fe8314057d..679d1eb36b 100644
--- a/tests/util/caches/test_ttlcache.py
+++ b/tests/util/caches/test_ttlcache.py
@@ -20,11 +20,11 @@ from tests import unittest
class CacheTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.mock_timer = Mock(side_effect=lambda: 100.0)
- self.cache = TTLCache("test_cache", self.mock_timer)
+ self.cache: TTLCache[str, str] = TTLCache("test_cache", self.mock_timer)
- def test_get(self):
+ def test_get(self) -> None:
"""simple set/get tests"""
self.cache.set("one", "1", 10)
self.cache.set("two", "2", 20)
@@ -59,7 +59,7 @@ class CacheTestCase(unittest.TestCase):
self.assertEqual(self.cache._metrics.hits, 4)
self.assertEqual(self.cache._metrics.misses, 5)
- def test_expiry(self):
+ def test_expiry(self) -> None:
self.cache.set("one", "1", 10)
self.cache.set("two", "2", 20)
self.cache.set("three", "3", 30)
diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py
index 5d1aa025d1..6913de24b9 100644
--- a/tests/util/test_check_dependencies.py
+++ b/tests/util/test_check_dependencies.py
@@ -40,7 +40,10 @@ class TestDependencyChecker(TestCase):
def mock_installed_package(
self, distribution: Optional[DummyDistribution]
) -> Generator[None, None, None]:
- """Pretend that looking up any distribution yields the given `distribution`."""
+ """Pretend that looking up any package yields the given `distribution`.
+
+ If `distribution = None`, we pretend that the package is not installed.
+ """
def mock_distribution(name: str):
if distribution is None:
@@ -81,7 +84,7 @@ class TestDependencyChecker(TestCase):
self.assertRaises(DependencyException, check_requirements)
def test_checks_ignore_dev_dependencies(self) -> None:
- """Bot generic and per-extra checks should ignore dev dependencies."""
+ """Both generic and per-extra checks should ignore dev dependencies."""
with patch(
"synapse.util.check_dependencies.metadata.requires",
return_value=["dummypkg >= 1; extra == 'mypy'"],
@@ -142,3 +145,16 @@ class TestDependencyChecker(TestCase):
with self.mock_installed_package(new_release_candidate):
# should not raise
check_requirements()
+
+ def test_setuptools_rust_ignored(self) -> None:
+ """Test a workaround for a `poetry build` problem. Reproduces #13926."""
+ with patch(
+ "synapse.util.check_dependencies.metadata.requires",
+ return_value=["setuptools_rust >= 1.3"],
+ ):
+ with self.mock_installed_package(None):
+ # should not raise, even if setuptools_rust is not installed
+ check_requirements()
+ with self.mock_installed_package(old):
+ # We also ignore old versions of setuptools_rust
+ check_requirements()
diff --git a/tests/util/test_macaroons.py b/tests/util/test_macaroons.py
index 32125f7bb7..40754a4711 100644
--- a/tests/util/test_macaroons.py
+++ b/tests/util/test_macaroons.py
@@ -84,34 +84,6 @@ class MacaroonGeneratorTestCase(TestCase):
)
self.assertEqual(user_id, "@user:tesths")
- def test_short_term_login_token(self):
- """Test the generation and verification of short-term login tokens"""
- token = self.macaroon_generator.generate_short_term_login_token(
- user_id="@user:tesths",
- auth_provider_id="oidc",
- auth_provider_session_id="sid",
- duration_in_ms=2 * 60 * 1000,
- )
-
- info = self.macaroon_generator.verify_short_term_login_token(token)
- self.assertEqual(info.user_id, "@user:tesths")
- self.assertEqual(info.auth_provider_id, "oidc")
- self.assertEqual(info.auth_provider_session_id, "sid")
-
- # Raises with another secret key
- with self.assertRaises(MacaroonVerificationFailedException):
- self.other_macaroon_generator.verify_short_term_login_token(token)
-
- # Wait a minute
- self.reactor.pump([60])
- # Shouldn't raise
- self.macaroon_generator.verify_short_term_login_token(token)
- # Wait another minute
- self.reactor.pump([60])
- # Should raise since it expired
- with self.assertRaises(MacaroonVerificationFailedException):
- self.macaroon_generator.verify_short_term_login_token(token)
-
def test_oidc_session_token(self):
"""Test the generation and verification of OIDC session cookies"""
state = "arandomstate"
|