diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 70d52b088c..28c767ecfd 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -992,3 +992,112 @@ class StateFilterDifferenceTestCase(TestCase):
StateFilter.none(),
StateFilter.all(),
)
+
+
+class StateFilterTestCase(TestCase):
+ def test_return_expanded(self):
+ """
+ Tests the behaviour of the return_expanded() function that expands
+ StateFilters to include more state types (for the sake of cache hit rate).
+ """
+
+ self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
+
+ self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
+
+ # Concrete-only state filters stay the same
+ # (Case: mixed filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ "some.other.state.type": {""},
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ "some.other.state.type": {""},
+ },
+ include_others=False,
+ ),
+ )
+
+ # Concrete-only state filters stay the same
+ # (Case: non-member-only filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {"some.other.state.type": {""}}, include_others=False
+ ).return_expanded(),
+ StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
+ )
+
+ # Concrete-only state filters stay the same
+ # (Case: member-only filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ },
+ include_others=False,
+ ),
+ )
+
+ # Wildcard member-only state filters stay the same
+ self.assertEqual(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ )
+
+ # If there is a wildcard in the non-member portion of the filter,
+ # it's expanded to include ALL non-member events.
+ # (Case: mixed filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ "some.other.state.type": None,
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze(
+ {EventTypes.Member: {"@wombat:test", "@alicia:test"}},
+ include_others=True,
+ ),
+ )
+
+ # If there is a wildcard in the non-member portion of the filter,
+ # it's expanded to include ALL non-member events.
+ # (Case: non-member-only filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ "some.other.state.type": None,
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
+ )
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ "some.other.state.type": None,
+ "yet.another.state.type": {"wombat"},
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
+ )
|