summary refs log tree commit diff
path: root/tests/util
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-11-22 17:35:54 -0500
committerGitHub <noreply@github.com>2022-11-22 17:35:54 -0500
commit4ae967cf6308e80b03da749f0cbaed36988e235e (patch)
treeada8f0dbd704f74c54feca002bb8a62c65021d26 /tests/util
parentOptimize `filter_events_for_client` for faster `/messages` - v2 (#14527) (diff)
downloadsynapse-4ae967cf6308e80b03da749f0cbaed36988e235e.tar.xz
Add missing type hints to test.util.caches (#14529)
Diffstat (limited to 'tests/util')
-rw-r--r--tests/util/caches/test_cached_call.py23
-rw-r--r--tests/util/caches/test_deferred_cache.py61
-rw-r--r--tests/util/caches/test_descriptors.py22
-rw-r--r--tests/util/caches/test_response_cache.py16
-rw-r--r--tests/util/caches/test_ttlcache.py8
5 files changed, 69 insertions, 61 deletions
diff --git a/tests/util/caches/test_cached_call.py b/tests/util/caches/test_cached_call.py
index 80b97167ba..9266f12590 100644
--- a/tests/util/caches/test_cached_call.py
+++ b/tests/util/caches/test_cached_call.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import NoReturn
 from unittest.mock import Mock
 
 from twisted.internet import defer
@@ -23,14 +24,14 @@ from tests.unittest import TestCase
 
 
 class CachedCallTestCase(TestCase):
-    def test_get(self):
+    def test_get(self) -> None:
         """
         Happy-path test case: makes a couple of calls and makes sure they behave
         correctly
         """
-        d = Deferred()
+        d: "Deferred[int]" = Deferred()
 
-        async def f():
+        async def f() -> int:
             return await d
 
         slow_call = Mock(side_effect=f)
@@ -43,7 +44,7 @@ class CachedCallTestCase(TestCase):
         # now fire off a couple of calls
         completed_results = []
 
-        async def r():
+        async def r() -> None:
             res = await cached_call.get()
             completed_results.append(res)
 
@@ -69,12 +70,12 @@ class CachedCallTestCase(TestCase):
         self.assertEqual(r3, 123)
         slow_call.assert_not_called()
 
-    def test_fast_call(self):
+    def test_fast_call(self) -> None:
         """
         Test the behaviour when the underlying function completes immediately
         """
 
-        async def f():
+        async def f() -> int:
             return 12
 
         fast_call = Mock(side_effect=f)
@@ -92,12 +93,12 @@ class CachedCallTestCase(TestCase):
 
 
 class RetryOnExceptionCachedCallTestCase(TestCase):
-    def test_get(self):
+    def test_get(self) -> None:
         # set up the RetryOnExceptionCachedCall around a function which will fail
         # (after a while)
-        d = Deferred()
+        d: "Deferred[int]" = Deferred()
 
-        async def f1():
+        async def f1() -> NoReturn:
             await d
             raise ValueError("moo")
 
@@ -110,7 +111,7 @@ class RetryOnExceptionCachedCallTestCase(TestCase):
         # now fire off a couple of calls
         completed_results = []
 
-        async def r():
+        async def r() -> None:
             try:
                 await cached_call.get()
             except Exception as e1:
@@ -137,7 +138,7 @@ class RetryOnExceptionCachedCallTestCase(TestCase):
         # to the getter
         d = Deferred()
 
-        async def f2():
+        async def f2() -> int:
             return await d
 
         slow_call.reset_mock()
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index 02b99b466a..f74d82b1dc 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from functools import partial
+from typing import List, Tuple
 
 from twisted.internet import defer
 
@@ -22,20 +23,20 @@ from tests.unittest import TestCase
 
 
 class DeferredCacheTestCase(TestCase):
-    def test_empty(self):
-        cache = DeferredCache("test")
+    def test_empty(self) -> None:
+        cache: DeferredCache[str, int] = DeferredCache("test")
         with self.assertRaises(KeyError):
             cache.get("foo")
 
-    def test_hit(self):
-        cache = DeferredCache("test")
+    def test_hit(self) -> None:
+        cache: DeferredCache[str, int] = DeferredCache("test")
         cache.prefill("foo", 123)
 
         self.assertEqual(self.successResultOf(cache.get("foo")), 123)
 
-    def test_hit_deferred(self):
-        cache = DeferredCache("test")
-        origin_d = defer.Deferred()
+    def test_hit_deferred(self) -> None:
+        cache: DeferredCache[str, int] = DeferredCache("test")
+        origin_d: "defer.Deferred[int]" = defer.Deferred()
         set_d = cache.set("k1", origin_d)
 
         # get should return an incomplete deferred
@@ -43,7 +44,7 @@ class DeferredCacheTestCase(TestCase):
         self.assertFalse(get_d.called)
 
         # add a callback that will make sure that the set_d gets called before the get_d
-        def check1(r):
+        def check1(r: str) -> str:
             self.assertTrue(set_d.called)
             return r
 
@@ -55,16 +56,16 @@ class DeferredCacheTestCase(TestCase):
         self.assertEqual(self.successResultOf(set_d), 99)
         self.assertEqual(self.successResultOf(get_d), 99)
 
-    def test_callbacks(self):
+    def test_callbacks(self) -> None:
         """Invalidation callbacks are called at the right time"""
-        cache = DeferredCache("test")
+        cache: DeferredCache[str, int] = DeferredCache("test")
         callbacks = set()
 
         # start with an entry, with a callback
         cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
 
         # now replace that entry with a pending result
-        origin_d = defer.Deferred()
+        origin_d: "defer.Deferred[int]" = defer.Deferred()
         set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
 
         # ... and also make a get request
@@ -89,15 +90,15 @@ class DeferredCacheTestCase(TestCase):
         cache.prefill("k1", 30)
         self.assertEqual(callbacks, {"set", "get"})
 
-    def test_set_fail(self):
-        cache = DeferredCache("test")
+    def test_set_fail(self) -> None:
+        cache: DeferredCache[str, int] = DeferredCache("test")
         callbacks = set()
 
         # start with an entry, with a callback
         cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
 
         # now replace that entry with a pending result
-        origin_d = defer.Deferred()
+        origin_d: defer.Deferred = defer.Deferred()
         set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
 
         # ... and also make a get request
@@ -126,9 +127,9 @@ class DeferredCacheTestCase(TestCase):
         cache.prefill("k1", 30)
         self.assertEqual(callbacks, {"prefill", "get2"})
 
-    def test_get_immediate(self):
-        cache = DeferredCache("test")
-        d1 = defer.Deferred()
+    def test_get_immediate(self) -> None:
+        cache: DeferredCache[str, int] = DeferredCache("test")
+        d1: "defer.Deferred[int]" = defer.Deferred()
         cache.set("key1", d1)
 
         # get_immediate should return default
@@ -142,27 +143,27 @@ class DeferredCacheTestCase(TestCase):
         v = cache.get_immediate("key1", 1)
         self.assertEqual(v, 2)
 
-    def test_invalidate(self):
-        cache = DeferredCache("test")
+    def test_invalidate(self) -> None:
+        cache: DeferredCache[Tuple[str], int] = DeferredCache("test")
         cache.prefill(("foo",), 123)
         cache.invalidate(("foo",))
 
         with self.assertRaises(KeyError):
             cache.get(("foo",))
 
-    def test_invalidate_all(self):
-        cache = DeferredCache("testcache")
+    def test_invalidate_all(self) -> None:
+        cache: DeferredCache[str, str] = DeferredCache("testcache")
 
         callback_record = [False, False]
 
-        def record_callback(idx):
+        def record_callback(idx: int) -> None:
             callback_record[idx] = True
 
         # add a couple of pending entries
-        d1 = defer.Deferred()
+        d1: "defer.Deferred[str]" = defer.Deferred()
         cache.set("key1", d1, partial(record_callback, 0))
 
-        d2 = defer.Deferred()
+        d2: "defer.Deferred[str]" = defer.Deferred()
         cache.set("key2", d2, partial(record_callback, 1))
 
         # lookup should return pending deferreds
@@ -193,8 +194,8 @@ class DeferredCacheTestCase(TestCase):
         with self.assertRaises(KeyError):
             cache.get("key1", None)
 
-    def test_eviction(self):
-        cache = DeferredCache(
+    def test_eviction(self) -> None:
+        cache: DeferredCache[int, str] = DeferredCache(
             "test", max_entries=2, apply_cache_factor_from_config=False
         )
 
@@ -208,8 +209,8 @@ class DeferredCacheTestCase(TestCase):
         cache.get(2)
         cache.get(3)
 
-    def test_eviction_lru(self):
-        cache = DeferredCache(
+    def test_eviction_lru(self) -> None:
+        cache: DeferredCache[int, str] = DeferredCache(
             "test", max_entries=2, apply_cache_factor_from_config=False
         )
 
@@ -227,8 +228,8 @@ class DeferredCacheTestCase(TestCase):
         cache.get(1)
         cache.get(3)
 
-    def test_eviction_iterable(self):
-        cache = DeferredCache(
+    def test_eviction_iterable(self) -> None:
+        cache: DeferredCache[int, List[str]] = DeferredCache(
             "test",
             max_entries=3,
             apply_cache_factor_from_config=False,
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 43475a307f..13f1edd533 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -13,11 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Iterable, Set, Tuple
+from typing import Iterable, Set, Tuple, cast
 from unittest import mock
 
 from twisted.internet import defer, reactor
 from twisted.internet.defer import CancelledError, Deferred
+from twisted.internet.interfaces import IReactorTime
 
 from synapse.api.errors import SynapseError
 from synapse.logging.context import (
@@ -37,8 +38,8 @@ logger = logging.getLogger(__name__)
 
 
 def run_on_reactor():
-    d = defer.Deferred()
-    reactor.callLater(0, d.callback, 0)
+    d: "Deferred[int]" = defer.Deferred()
+    cast(IReactorTime, reactor).callLater(0, d.callback, 0)
     return make_deferred_yieldable(d)
 
 
@@ -224,7 +225,8 @@ class DescriptorTestCase(unittest.TestCase):
         callbacks: Set[str] = set()
 
         # set off an asynchronous request
-        obj.result = origin_d = defer.Deferred()
+        origin_d: Deferred = defer.Deferred()
+        obj.result = origin_d
 
         d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
         self.assertFalse(d1.called)
@@ -262,7 +264,7 @@ class DescriptorTestCase(unittest.TestCase):
         """Check that logcontexts are set and restored correctly when
         using the cache."""
 
-        complete_lookup = defer.Deferred()
+        complete_lookup: Deferred = defer.Deferred()
 
         class Cls:
             @descriptors.cached()
@@ -772,10 +774,14 @@ class CachedListDescriptorTestCase(unittest.TestCase):
 
             @descriptors.cachedList(cached_method_name="fn", list_name="args1")
             async def list_fn(self, args1, arg2):
-                assert current_context().name == "c1"
+                context = current_context()
+                assert isinstance(context, LoggingContext)
+                assert context.name == "c1"
                 # we want this to behave like an asynchronous function
                 await run_on_reactor()
-                assert current_context().name == "c1"
+                context = current_context()
+                assert isinstance(context, LoggingContext)
+                assert context.name == "c1"
                 return self.mock(args1, arg2)
 
         with LoggingContext("c1") as c1:
@@ -834,7 +840,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
                 return self.mock(args1)
 
         obj = Cls()
-        deferred_result = Deferred()
+        deferred_result: "Deferred[dict]" = Deferred()
         obj.mock.return_value = deferred_result
 
         # start off several concurrent lookups of the same key
diff --git a/tests/util/caches/test_response_cache.py b/tests/util/caches/test_response_cache.py
index 025b73e32f..f09eeecada 100644
--- a/tests/util/caches/test_response_cache.py
+++ b/tests/util/caches/test_response_cache.py
@@ -35,7 +35,7 @@ class ResponseCacheTestCase(TestCase):
                 (These have cache with a short timeout_ms=, shorter than will be tested through advancing the clock)
     """
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.reactor, self.clock = get_clock()
 
     def with_cache(self, name: str, ms: int = 0) -> ResponseCache:
@@ -49,7 +49,7 @@ class ResponseCacheTestCase(TestCase):
         await self.clock.sleep(1)
         return o
 
-    def test_cache_hit(self):
+    def test_cache_hit(self) -> None:
         cache = self.with_cache("keeping_cache", ms=9001)
 
         expected_result = "howdy"
@@ -74,7 +74,7 @@ class ResponseCacheTestCase(TestCase):
             "cache should still have the result",
         )
 
-    def test_cache_miss(self):
+    def test_cache_miss(self) -> None:
         cache = self.with_cache("trashing_cache", ms=0)
 
         expected_result = "howdy"
@@ -90,7 +90,7 @@ class ResponseCacheTestCase(TestCase):
         )
         self.assertCountEqual([], cache.keys(), "cache should not have the result now")
 
-    def test_cache_expire(self):
+    def test_cache_expire(self) -> None:
         cache = self.with_cache("short_cache", ms=1000)
 
         expected_result = "howdy"
@@ -115,7 +115,7 @@ class ResponseCacheTestCase(TestCase):
         self.reactor.pump((2,))
         self.assertCountEqual([], cache.keys(), "cache should not have the result now")
 
-    def test_cache_wait_hit(self):
+    def test_cache_wait_hit(self) -> None:
         cache = self.with_cache("neutral_cache")
 
         expected_result = "howdy"
@@ -131,7 +131,7 @@ class ResponseCacheTestCase(TestCase):
 
         self.assertEqual(expected_result, self.successResultOf(wrap_d))
 
-    def test_cache_wait_expire(self):
+    def test_cache_wait_expire(self) -> None:
         cache = self.with_cache("medium_cache", ms=3000)
 
         expected_result = "howdy"
@@ -162,7 +162,7 @@ class ResponseCacheTestCase(TestCase):
         self.assertCountEqual([], cache.keys(), "cache should not have the result now")
 
     @parameterized.expand([(True,), (False,)])
-    def test_cache_context_nocache(self, should_cache: bool):
+    def test_cache_context_nocache(self, should_cache: bool) -> None:
         """If the callback clears the should_cache bit, the result should not be cached"""
         cache = self.with_cache("medium_cache", ms=3000)
 
@@ -170,7 +170,7 @@ class ResponseCacheTestCase(TestCase):
 
         call_count = 0
 
-        async def non_caching(o: str, cache_context: ResponseCacheContext[int]):
+        async def non_caching(o: str, cache_context: ResponseCacheContext[int]) -> str:
             nonlocal call_count
             call_count += 1
             await self.clock.sleep(1)
diff --git a/tests/util/caches/test_ttlcache.py b/tests/util/caches/test_ttlcache.py
index fe8314057d..679d1eb36b 100644
--- a/tests/util/caches/test_ttlcache.py
+++ b/tests/util/caches/test_ttlcache.py
@@ -20,11 +20,11 @@ from tests import unittest
 
 
 class CacheTestCase(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.mock_timer = Mock(side_effect=lambda: 100.0)
-        self.cache = TTLCache("test_cache", self.mock_timer)
+        self.cache: TTLCache[str, str] = TTLCache("test_cache", self.mock_timer)
 
-    def test_get(self):
+    def test_get(self) -> None:
         """simple set/get tests"""
         self.cache.set("one", "1", 10)
         self.cache.set("two", "2", 20)
@@ -59,7 +59,7 @@ class CacheTestCase(unittest.TestCase):
         self.assertEqual(self.cache._metrics.hits, 4)
         self.assertEqual(self.cache._metrics.misses, 5)
 
-    def test_expiry(self):
+    def test_expiry(self) -> None:
         self.cache.set("one", "1", 10)
         self.cache.set("two", "2", 20)
         self.cache.set("three", "3", 30)