diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index dcc2b4be89..3f0d8139f8 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -383,3 +383,34 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
# the items should still be in the cache
self.assertEqual(cache.get("key1"), 1)
self.assertEqual(cache.get("key2"), 2)
+
+
+class ExtraIndexLruCacheTestCase(unittest.HomeserverTestCase):
+ def test_invalidate_simple(self) -> None:
+ cache: LruCache[str, int] = LruCache(10, extra_index_cb=lambda k, v: str(v))
+ cache["key1"] = 1
+ cache["key2"] = 2
+
+ cache.invalidate_on_extra_index("key1")
+ self.assertEqual(cache.get("key1"), 1)
+ self.assertEqual(cache.get("key2"), 2)
+
+ cache.invalidate_on_extra_index("1")
+ self.assertEqual(cache.get("key1"), None)
+ self.assertEqual(cache.get("key2"), 2)
+
+ def test_invalidate_multi(self) -> None:
+ cache: LruCache[str, int] = LruCache(10, extra_index_cb=lambda k, v: str(v))
+ cache["key1"] = 1
+ cache["key2"] = 1
+ cache["key3"] = 2
+
+ cache.invalidate_on_extra_index("key1")
+ self.assertEqual(cache.get("key1"), 1)
+ self.assertEqual(cache.get("key2"), 1)
+ self.assertEqual(cache.get("key3"), 2)
+
+ cache.invalidate_on_extra_index("1")
+ self.assertEqual(cache.get("key1"), None)
+ self.assertEqual(cache.get("key2"), None)
+ self.assertEqual(cache.get("key3"), 2)
|