summary refs log tree commit diff
path: root/tests/util
diff options
context:
space:
mode:
Diffstat (limited to 'tests/util')
-rw-r--r--tests/util/caches/test_descriptors.py33
1 files changed, 32 insertions, 1 deletions
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 48e616ac74..90861fe522 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Set
+from typing import Iterable, Set, Tuple
 from unittest import mock
 
 from twisted.internet import defer, reactor
@@ -1008,3 +1008,34 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             obj.inner_context_was_finished, "Tried to restart a finished logcontext"
         )
         self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+    def test_num_args_mismatch(self):
+        """
+        Make sure someone does not accidentally use @cachedList on a method with
+        a mismatch in the number args to the underlying single cache method.
+        """
+
+        class Cls:
+            @descriptors.cached(tree=True)
+            def fn(self, room_id, event_id):
+                pass
+
+            # This is wrong ❌. `@cachedList` expects to be given the same number
+            # of arguments as the underlying cached function, just with one of
+            # the arguments being an iterable
+            @descriptors.cachedList(cached_method_name="fn", list_name="keys")
+            def list_fn(self, keys: Iterable[Tuple[str, str]]):
+                pass
+
+            # Corrected syntax ✅
+            #
+            # @cachedList(cached_method_name="fn", list_name="event_ids")
+            # async def list_fn(
+            #     self, room_id: str, event_ids: Collection[str],
+            # )
+
+        obj = Cls()
+
+        # Make sure this raises an error about the arg mismatch
+        with self.assertRaises(Exception):
+            obj.list_fn([("foo", "bar")])