summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-09-20 07:48:55 -0400
committerGitHub <noreply@github.com>2023-09-20 07:48:55 -0400
commit7ec0a141b4bdda0fa67cb1f2af7f321b9963f0b8 (patch)
tree8a500d3fce31d337d50d43c75df70d243371acb2 /tests
parentReturn immutable objects for cachedList decorators (#16350) (diff)
downloadsynapse-7ec0a141b4bdda0fa67cb1f2af7f321b9963f0b8.tar.xz
Convert more cached return values to immutable types (#16356)
Diffstat (limited to 'tests')
-rw-r--r--tests/util/caches/test_descriptors.py35
1 files changed, 19 insertions, 16 deletions
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 168419f440..7e8725e610 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -15,10 +15,10 @@
 import logging
 from typing import (
     Any,
-    Dict,
     Generator,
     Iterable,
     List,
+    Mapping,
     NoReturn,
     Optional,
     Set,
@@ -96,7 +96,7 @@ class DescriptorTestCase(unittest.TestCase):
                 self.mock = mock.Mock()
 
             @descriptors.cached(num_args=1)
-            def fn(self, arg1: int, arg2: int) -> mock.Mock:
+            def fn(self, arg1: int, arg2: int) -> str:
                 return self.mock(arg1, arg2)
 
         obj = Cls()
@@ -228,8 +228,9 @@ class DescriptorTestCase(unittest.TestCase):
             call_count = 0
 
             @cached()
-            def fn(self, arg1: int) -> Optional[Deferred]:
+            def fn(self, arg1: int) -> Deferred:
                 self.call_count += 1
+                assert self.result is not None
                 return self.result
 
         obj = Cls()
@@ -401,21 +402,21 @@ class DescriptorTestCase(unittest.TestCase):
                 self.mock = mock.Mock()
 
             @descriptors.cached(iterable=True)
-            def fn(self, arg1: int, arg2: int) -> List[str]:
+            def fn(self, arg1: int, arg2: int) -> Tuple[str, ...]:
                 return self.mock(arg1, arg2)
 
         obj = Cls()
 
-        obj.mock.return_value = ["spam", "eggs"]
+        obj.mock.return_value = ("spam", "eggs")
         r = obj.fn(1, 2)
-        self.assertEqual(r.result, ["spam", "eggs"])
+        self.assertEqual(r.result, ("spam", "eggs"))
         obj.mock.assert_called_once_with(1, 2)
         obj.mock.reset_mock()
 
         # a call with different params should call the mock again
-        obj.mock.return_value = ["chips"]
+        obj.mock.return_value = ("chips",)
         r = obj.fn(1, 3)
-        self.assertEqual(r.result, ["chips"])
+        self.assertEqual(r.result, ("chips",))
         obj.mock.assert_called_once_with(1, 3)
         obj.mock.reset_mock()
 
@@ -423,9 +424,9 @@ class DescriptorTestCase(unittest.TestCase):
         self.assertEqual(len(obj.fn.cache.cache), 3)
 
         r = obj.fn(1, 2)
-        self.assertEqual(r.result, ["spam", "eggs"])
+        self.assertEqual(r.result, ("spam", "eggs"))
         r = obj.fn(1, 3)
-        self.assertEqual(r.result, ["chips"])
+        self.assertEqual(r.result, ("chips",))
         obj.mock.assert_not_called()
 
     def test_cache_iterable_with_sync_exception(self) -> None:
@@ -784,7 +785,9 @@ class CachedListDescriptorTestCase(unittest.TestCase):
                 pass
 
             @descriptors.cachedList(cached_method_name="fn", list_name="args1")
-            async def list_fn(self, args1: Iterable[int], arg2: int) -> Dict[int, str]:
+            async def list_fn(
+                self, args1: Iterable[int], arg2: int
+            ) -> Mapping[int, str]:
                 context = current_context()
                 assert isinstance(context, LoggingContext)
                 assert context.name == "c1"
@@ -847,11 +850,11 @@ class CachedListDescriptorTestCase(unittest.TestCase):
                 pass
 
             @descriptors.cachedList(cached_method_name="fn", list_name="args1")
-            def list_fn(self, args1: List[int]) -> "Deferred[dict]":
+            def list_fn(self, args1: List[int]) -> "Deferred[Mapping[int, str]]":
                 return self.mock(args1)
 
         obj = Cls()
-        deferred_result: "Deferred[dict]" = Deferred()
+        deferred_result: "Deferred[Mapping[int, str]]" = Deferred()
         obj.mock.return_value = deferred_result
 
         # start off several concurrent lookups of the same key
@@ -890,7 +893,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
                 pass
 
             @descriptors.cachedList(cached_method_name="fn", list_name="args1")
-            async def list_fn(self, args1: List[int], arg2: int) -> Dict[int, str]:
+            async def list_fn(self, args1: List[int], arg2: int) -> Mapping[int, str]:
                 # we want this to behave like an asynchronous function
                 await run_on_reactor()
                 return self.mock(args1, arg2)
@@ -929,7 +932,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
                 pass
 
             @cachedList(cached_method_name="fn", list_name="args")
-            async def list_fn(self, args: List[int]) -> Dict[int, str]:
+            async def list_fn(self, args: List[int]) -> Mapping[int, str]:
                 await complete_lookup
                 return {arg: str(arg) for arg in args}
 
@@ -964,7 +967,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
                 pass
 
             @cachedList(cached_method_name="fn", list_name="args")
-            async def list_fn(self, args: List[int]) -> Dict[int, str]:
+            async def list_fn(self, args: List[int]) -> Mapping[int, str]:
                 await make_deferred_yieldable(complete_lookup)
                 self.inner_context_was_finished = current_context().finished
                 return {arg: str(arg) for arg in args}