diff --git a/changelog.d/12016.misc b/changelog.d/12016.misc
new file mode 100644
index 0000000000..8856ef46a9
--- /dev/null
+++ b/changelog.d/12016.misc
@@ -0,0 +1 @@
+Fix bug in `StateFilter.return_expanded()` and add some tests.
\ No newline at end of file
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 913448f0f9..e79ecf64a0 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -204,13 +204,16 @@ class StateFilter:
if get_all_members:
# We want to return everything.
return StateFilter.all()
- else:
+ elif EventTypes.Member in self.types:
# We want to return all non-members, but only particular
# memberships
return StateFilter(
types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
include_others=True,
)
+ else:
+ # We want to return all non-members
+ return _ALL_NON_MEMBER_STATE_FILTER
def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
"""Converts the filter to an SQL clause.
@@ -528,6 +531,9 @@ class StateFilter:
_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
+_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
+ types=frozendict({EventTypes.Member: frozenset()}), include_others=True
+)
_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 70d52b088c..28c767ecfd 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -992,3 +992,112 @@ class StateFilterDifferenceTestCase(TestCase):
StateFilter.none(),
StateFilter.all(),
)
+
+
+class StateFilterTestCase(TestCase):
+ def test_return_expanded(self):
+ """
+ Tests the behaviour of the return_expanded() function that expands
+ StateFilters to include more state types (for the sake of cache hit rate).
+ """
+
+ self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
+
+ self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
+
+ # Concrete-only state filters stay the same
+ # (Case: mixed filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ "some.other.state.type": {""},
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ "some.other.state.type": {""},
+ },
+ include_others=False,
+ ),
+ )
+
+ # Concrete-only state filters stay the same
+ # (Case: non-member-only filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {"some.other.state.type": {""}}, include_others=False
+ ).return_expanded(),
+ StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
+ )
+
+ # Concrete-only state filters stay the same
+ # (Case: member-only filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ },
+ include_others=False,
+ ),
+ )
+
+ # Wildcard member-only state filters stay the same
+ self.assertEqual(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ )
+
+ # If there is a wildcard in the non-member portion of the filter,
+ # it's expanded to include ALL non-member events.
+ # (Case: mixed filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ "some.other.state.type": None,
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze(
+ {EventTypes.Member: {"@wombat:test", "@alicia:test"}},
+ include_others=True,
+ ),
+ )
+
+ # If there is a wildcard in the non-member portion of the filter,
+ # it's expanded to include ALL non-member events.
+ # (Case: non-member-only filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ "some.other.state.type": None,
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
+ )
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ "some.other.state.type": None,
+ "yet.another.state.type": {"wombat"},
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
+ )
|