summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-12-13 00:54:46 +0000
committerGitHub <noreply@github.com>2022-12-13 00:54:46 +0000
commite2a1adbf5d11288f2134ced1f84c6ffdd91a9357 (patch)
treea7e8ec0eee2585f55b6f275425a4007a29b6372e /tests
parentEnable `--warn-redundant-casts` option in mypy (#14671) (diff)
downloadsynapse-e2a1adbf5d11288f2134ced1f84c6ffdd91a9357.tar.xz
Allow selecting "prejoin" events by state keys (#14642)
* Declare new config

* Parse new config

* Read new config

* Don't use trial/our TestCase where it's not needed

Before:

```
$ time trial tests/events/test_utils.py > /dev/null

real	0m2.277s
user	0m2.186s
sys	0m0.083s
```

After:
```
$ time trial tests/events/test_utils.py > /dev/null

real	0m0.566s
user	0m0.508s
sys	0m0.056s
```

* Helper to upsert to event fields

without exceeding size limits.

* Use helper when adding invite/knock state

Now that we allow admins to include events in prejoin room state with
arbitrary state keys, be a good Matrix citizen and ensure they don't
accidentally create an oversized event.

* Changelog

* Move StateFilter tests

should have done this in #14668

* Add extra methods to StateFilter

* Use StateFilter

* Ensure test file enforces typed defs; alphabetise

* Workaround surprising get_current_state_ids

* Whoops, fix mypy
Diffstat (limited to 'tests')
-rw-r--r--tests/config/test_api.py145
-rw-r--r--tests/events/test_utils.py35
-rw-r--r--tests/storage/test_state.py623
-rw-r--r--tests/types/__init__.py0
-rw-r--r--tests/types/test_state.py627
5 files changed, 803 insertions, 627 deletions
diff --git a/tests/config/test_api.py b/tests/config/test_api.py
new file mode 100644
index 0000000000..6773c9a277
--- /dev/null
+++ b/tests/config/test_api.py
@@ -0,0 +1,145 @@
+from unittest import TestCase as StdlibTestCase
+
+import yaml
+
+from synapse.config import ConfigError
+from synapse.config.api import ApiConfig
+from synapse.types.state import StateFilter
+
+DEFAULT_PREJOIN_STATE_PAIRS = {
+    ("m.room.join_rules", ""),
+    ("m.room.canonical_alias", ""),
+    ("m.room.avatar", ""),
+    ("m.room.encryption", ""),
+    ("m.room.name", ""),
+    ("m.room.create", ""),
+    ("m.room.topic", ""),
+}
+
+
+class TestRoomPrejoinState(StdlibTestCase):
+    def read_config(self, source: str) -> ApiConfig:
+        config = ApiConfig()
+        config.read_config(yaml.safe_load(source))
+        return config
+
+    def test_no_prejoin_state(self) -> None:
+        config = self.read_config("foo: bar")
+        self.assertFalse(config.room_prejoin_state.has_wildcards())
+        self.assertEqual(
+            set(config.room_prejoin_state.concrete_types()), DEFAULT_PREJOIN_STATE_PAIRS
+        )
+
+    def test_disable_default_event_types(self) -> None:
+        config = self.read_config(
+            """
+room_prejoin_state:
+    disable_default_event_types: true
+        """
+        )
+        self.assertEqual(config.room_prejoin_state, StateFilter.none())
+
+    def test_event_without_state_key(self) -> None:
+        config = self.read_config(
+            """
+room_prejoin_state:
+    disable_default_event_types: true
+    additional_event_types:
+        - foo
+        """
+        )
+        self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
+        self.assertEqual(config.room_prejoin_state.concrete_types(), [])
+
+    def test_event_with_specific_state_key(self) -> None:
+        config = self.read_config(
+            """
+room_prejoin_state:
+    disable_default_event_types: true
+    additional_event_types:
+        - [foo, bar]
+        """
+        )
+        self.assertFalse(config.room_prejoin_state.has_wildcards())
+        self.assertEqual(
+            set(config.room_prejoin_state.concrete_types()),
+            {("foo", "bar")},
+        )
+
+    def test_repeated_event_with_specific_state_key(self) -> None:
+        config = self.read_config(
+            """
+room_prejoin_state:
+    disable_default_event_types: true
+    additional_event_types:
+        - [foo, bar]
+        - [foo, baz]
+        """
+        )
+        self.assertFalse(config.room_prejoin_state.has_wildcards())
+        self.assertEqual(
+            set(config.room_prejoin_state.concrete_types()),
+            {("foo", "bar"), ("foo", "baz")},
+        )
+
+    def test_no_specific_state_key_overrides_specific_state_key(self) -> None:
+        config = self.read_config(
+            """
+room_prejoin_state:
+    disable_default_event_types: true
+    additional_event_types:
+        - [foo, bar]
+        - foo
+        """
+        )
+        self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
+        self.assertEqual(config.room_prejoin_state.concrete_types(), [])
+
+        config = self.read_config(
+            """
+room_prejoin_state:
+    disable_default_event_types: true
+    additional_event_types:
+        - foo
+        - [foo, bar]
+        """
+        )
+        self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
+        self.assertEqual(config.room_prejoin_state.concrete_types(), [])
+
+    def test_bad_event_type_entry_raises(self) -> None:
+        with self.assertRaises(ConfigError):
+            self.read_config(
+                """
+room_prejoin_state:
+    additional_event_types:
+        - []
+            """
+            )
+
+        with self.assertRaises(ConfigError):
+            self.read_config(
+                """
+room_prejoin_state:
+    additional_event_types:
+        - [a]
+            """
+            )
+
+        with self.assertRaises(ConfigError):
+            self.read_config(
+                """
+room_prejoin_state:
+    additional_event_types:
+        - [a, b, c]
+            """
+            )
+
+        with self.assertRaises(ConfigError):
+            self.read_config(
+                """
+room_prejoin_state:
+    additional_event_types:
+        - [true, 1.23]
+            """
+            )
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index b1c47efac7..a79256846f 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -12,19 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import unittest as stdlib_unittest
+
 from synapse.api.constants import EventContentFields
 from synapse.api.room_versions import RoomVersions
 from synapse.events import make_event_from_dict
 from synapse.events.utils import (
     SerializeEventConfig,
     copy_and_fixup_power_levels_contents,
+    maybe_upsert_event_field,
     prune_event,
     serialize_event,
 )
 from synapse.util.frozenutils import freeze
 
-from tests import unittest
-
 
 def MockEvent(**kwargs):
     if "event_id" not in kwargs:
@@ -34,7 +35,31 @@ def MockEvent(**kwargs):
     return make_event_from_dict(kwargs)
 
 
-class PruneEventTestCase(unittest.TestCase):
+class TestMaybeUpsertEventField(stdlib_unittest.TestCase):
+    def test_update_okay(self) -> None:
+        event = make_event_from_dict({"event_id": "$1234"})
+        success = maybe_upsert_event_field(event, event.unsigned, "key", "value")
+        self.assertTrue(success)
+        self.assertEqual(event.unsigned["key"], "value")
+
+    def test_update_not_okay(self) -> None:
+        event = make_event_from_dict({"event_id": "$1234"})
+        LARGE_STRING = "a" * 100_000
+        success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING)
+        self.assertFalse(success)
+        self.assertNotIn("key", event.unsigned)
+
+    def test_update_not_okay_leaves_original_value(self) -> None:
+        event = make_event_from_dict(
+            {"event_id": "$1234", "unsigned": {"key": "value"}}
+        )
+        LARGE_STRING = "a" * 100_000
+        success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING)
+        self.assertFalse(success)
+        self.assertEqual(event.unsigned["key"], "value")
+
+
+class PruneEventTestCase(stdlib_unittest.TestCase):
     def run_test(self, evdict, matchdict, **kwargs):
         """
         Asserts that a new event constructed with `evdict` will look like
@@ -391,7 +416,7 @@ class PruneEventTestCase(unittest.TestCase):
         )
 
 
-class SerializeEventTestCase(unittest.TestCase):
+class SerializeEventTestCase(stdlib_unittest.TestCase):
     def serialize(self, ev, fields):
         return serialize_event(
             ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields)
@@ -513,7 +538,7 @@ class SerializeEventTestCase(unittest.TestCase):
             )
 
 
-class CopyPowerLevelsContentTestCase(unittest.TestCase):
+class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
     def setUp(self) -> None:
         self.test_content = {
             "ban": 50,
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index a433e70870..bad7f0bc60 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -26,7 +26,7 @@ from synapse.types import JsonDict, RoomID, StateMap, UserID
 from synapse.types.state import StateFilter
 from synapse.util import Clock
 
-from tests.unittest import HomeserverTestCase, TestCase
+from tests.unittest import HomeserverTestCase
 
 logger = logging.getLogger(__name__)
 
@@ -494,624 +494,3 @@ class StateStoreTestCase(HomeserverTestCase):
 
         self.assertEqual(is_all, True)
         self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
-
-
-class StateFilterDifferenceTestCase(TestCase):
-    def assert_difference(
-        self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
-    ) -> None:
-        self.assertEqual(
-            minuend.approx_difference(subtrahend),
-            expected,
-            f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
-        )
-
-    def test_state_filter_difference_no_include_other_minus_no_include_other(
-        self,
-    ) -> None:
-        """
-        Tests the StateFilter.approx_difference method
-        where, in a.approx_difference(b), both a and b do not have the
-        include_others flag set.
-        """
-        # (wildcard on state keys) - (wildcard on state keys):
-        self.assert_difference(
-            StateFilter.freeze(
-                {EventTypes.Member: None, EventTypes.Create: None},
-                include_others=False,
-            ),
-            StateFilter.freeze(
-                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
-                include_others=False,
-            ),
-            StateFilter.freeze({EventTypes.Create: None}, include_others=False),
-        )
-
-        # (wildcard on state keys) - (specific state keys)
-        # This one is an over-approximation because we can't represent
-        # 'all state keys except a few named examples'
-        self.assert_difference(
-            StateFilter.freeze({EventTypes.Member: None}, include_others=False),
-            StateFilter.freeze(
-                {EventTypes.Member: {"@wombat:spqr"}},
-                include_others=False,
-            ),
-            StateFilter.freeze({EventTypes.Member: None}, include_others=False),
-        )
-
-        # (wildcard on state keys) - (no state keys)
-        self.assert_difference(
-            StateFilter.freeze(
-                {EventTypes.Member: None},
-                include_others=False,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: set(),
-                },
-                include_others=False,
-            ),
-            StateFilter.freeze(
-                {EventTypes.Member: None},
-                include_others=False,
-            ),
-        )
-
-        # (specific state keys) - (wildcard on state keys):
-        self.assert_difference(
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
-                    EventTypes.CanonicalAlias: {""},
-                },
-                include_others=False,
-            ),
-            StateFilter.freeze(
-                {EventTypes.Member: None},
-                include_others=False,
-            ),
-            StateFilter.freeze(
-                {EventTypes.CanonicalAlias: {""}},
-                include_others=False,
-            ),
-        )
-
-        # (specific state keys) - (specific state keys)
-        self.assert_difference(
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
-                    EventTypes.CanonicalAlias: {""},
-                },
-                include_others=False,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr"},
-                },
-                include_others=False,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@spqr:spqr"},
-                    EventTypes.CanonicalAlias: {""},
-                },
-                include_others=False,
-            ),
-        )
-
-        # (specific state keys) - (no state keys)
-        self.assert_difference(
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
-                    EventTypes.CanonicalAlias: {""},
-                },
-                include_others=False,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: set(),
-                },
-                include_others=False,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
-                    EventTypes.CanonicalAlias: {""},
-                },
-                include_others=False,
-            ),
-        )
-
-    def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
-        """
-        Tests the StateFilter.approx_difference method
-        where, in a.approx_difference(b), only a has the include_others flag set.
-        """
-        # (wildcard on state keys) - (wildcard on state keys):
-        self.assert_difference(
-            StateFilter.freeze(
-                {EventTypes.Member: None, EventTypes.Create: None},
-                include_others=True,
-            ),
-            StateFilter.freeze(
-                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
-                include_others=False,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.Create: None,
-                    EventTypes.Member: set(),
-                    EventTypes.CanonicalAlias: set(),
-                },
-                include_others=True,
-            ),
-        )
-
-        # (wildcard on state keys) - (specific state keys)
-        # This one is an over-approximation because we can't represent
-        # 'all state keys except a few named examples'
-        # This also shows that the resultant state filter is normalised.
-        self.assert_difference(
-            StateFilter.freeze({EventTypes.Member: None}, include_others=True),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr"},
-                    EventTypes.Create: {""},
-                },
-                include_others=False,
-            ),
-            StateFilter(types=frozendict(), include_others=True),
-        )
-
-        # (wildcard on state keys) - (no state keys)
-        self.assert_difference(
-            StateFilter.freeze(
-                {EventTypes.Member: None},
-                include_others=True,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: set(),
-                },
-                include_others=False,
-            ),
-            StateFilter(
-                types=frozendict(),
-                include_others=True,
-            ),
-        )
-
-        # (specific state keys) - (wildcard on state keys):
-        self.assert_difference(
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
-                    EventTypes.CanonicalAlias: {""},
-                },
-                include_others=True,
-            ),
-            StateFilter.freeze(
-                {EventTypes.Member: None},
-                include_others=False,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.CanonicalAlias: {""},
-                    EventTypes.Member: set(),
-                },
-                include_others=True,
-            ),
-        )
-
-        # (specific state keys) - (specific state keys)
-        self.assert_difference(
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
-                    EventTypes.CanonicalAlias: {""},
-                },
-                include_others=True,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr"},
-                },
-                include_others=False,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@spqr:spqr"},
-                    EventTypes.CanonicalAlias: {""},
-                },
-                include_others=True,
-            ),
-        )
-
-        # (specific state keys) - (no state keys)
-        self.assert_difference(
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
-                    EventTypes.CanonicalAlias: {""},
-                },
-                include_others=True,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: set(),
-                },
-                include_others=False,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
-                    EventTypes.CanonicalAlias: {""},
-                },
-                include_others=True,
-            ),
-        )
-
-    def test_state_filter_difference_include_other_minus_include_other(self) -> None:
-        """
-        Tests the StateFilter.approx_difference method
-        where, in a.approx_difference(b), both a and b have the include_others
-        flag set.
-        """
-        # (wildcard on state keys) - (wildcard on state keys):
-        self.assert_difference(
-            StateFilter.freeze(
-                {EventTypes.Member: None, EventTypes.Create: None},
-                include_others=True,
-            ),
-            StateFilter.freeze(
-                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
-                include_others=True,
-            ),
-            StateFilter(types=frozendict(), include_others=False),
-        )
-
-        # (wildcard on state keys) - (specific state keys)
-        # This one is an over-approximation because we can't represent
-        # 'all state keys except a few named examples'
-        self.assert_difference(
-            StateFilter.freeze({EventTypes.Member: None}, include_others=True),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr"},
-                    EventTypes.CanonicalAlias: {""},
-                },
-                include_others=True,
-            ),
-            StateFilter.freeze(
-                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
-                include_others=False,
-            ),
-        )
-
-        # (wildcard on state keys) - (no state keys)
-        self.assert_difference(
-            StateFilter.freeze(
-                {EventTypes.Member: None},
-                include_others=True,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: set(),
-                },
-                include_others=True,
-            ),
-            StateFilter.freeze(
-                {EventTypes.Member: None},
-                include_others=False,
-            ),
-        )
-
-        # (specific state keys) - (wildcard on state keys):
-        self.assert_difference(
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
-                    EventTypes.CanonicalAlias: {""},
-                },
-                include_others=True,
-            ),
-            StateFilter.freeze(
-                {EventTypes.Member: None},
-                include_others=True,
-            ),
-            StateFilter(
-                types=frozendict(),
-                include_others=False,
-            ),
-        )
-
-        # (specific state keys) - (specific state keys)
-        # This one is an over-approximation because we can't represent
-        # 'all state keys except a few named examples'
-        self.assert_difference(
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
-                    EventTypes.CanonicalAlias: {""},
-                    EventTypes.Create: {""},
-                },
-                include_others=True,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr"},
-                    EventTypes.Create: set(),
-                },
-                include_others=True,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@spqr:spqr"},
-                    EventTypes.Create: {""},
-                },
-                include_others=False,
-            ),
-        )
-
-        # (specific state keys) - (no state keys)
-        self.assert_difference(
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
-                    EventTypes.CanonicalAlias: {""},
-                },
-                include_others=True,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: set(),
-                },
-                include_others=True,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
-                },
-                include_others=False,
-            ),
-        )
-
-    def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
-        """
-        Tests the StateFilter.approx_difference method
-        where, in a.approx_difference(b), only b has the include_others flag set.
-        """
-        # (wildcard on state keys) - (wildcard on state keys):
-        self.assert_difference(
-            StateFilter.freeze(
-                {EventTypes.Member: None, EventTypes.Create: None},
-                include_others=False,
-            ),
-            StateFilter.freeze(
-                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
-                include_others=True,
-            ),
-            StateFilter(types=frozendict(), include_others=False),
-        )
-
-        # (wildcard on state keys) - (specific state keys)
-        # This one is an over-approximation because we can't represent
-        # 'all state keys except a few named examples'
-        self.assert_difference(
-            StateFilter.freeze({EventTypes.Member: None}, include_others=False),
-            StateFilter.freeze(
-                {EventTypes.Member: {"@wombat:spqr"}},
-                include_others=True,
-            ),
-            StateFilter.freeze({EventTypes.Member: None}, include_others=False),
-        )
-
-        # (wildcard on state keys) - (no state keys)
-        self.assert_difference(
-            StateFilter.freeze(
-                {EventTypes.Member: None},
-                include_others=False,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: set(),
-                },
-                include_others=True,
-            ),
-            StateFilter.freeze(
-                {EventTypes.Member: None},
-                include_others=False,
-            ),
-        )
-
-        # (specific state keys) - (wildcard on state keys):
-        self.assert_difference(
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
-                    EventTypes.CanonicalAlias: {""},
-                },
-                include_others=False,
-            ),
-            StateFilter.freeze(
-                {EventTypes.Member: None},
-                include_others=True,
-            ),
-            StateFilter(
-                types=frozendict(),
-                include_others=False,
-            ),
-        )
-
-        # (specific state keys) - (specific state keys)
-        # This one is an over-approximation because we can't represent
-        # 'all state keys except a few named examples'
-        self.assert_difference(
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
-                    EventTypes.CanonicalAlias: {""},
-                },
-                include_others=False,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr"},
-                },
-                include_others=True,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@spqr:spqr"},
-                },
-                include_others=False,
-            ),
-        )
-
-        # (specific state keys) - (no state keys)
-        self.assert_difference(
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
-                    EventTypes.CanonicalAlias: {""},
-                },
-                include_others=False,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: set(),
-                },
-                include_others=True,
-            ),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
-                },
-                include_others=False,
-            ),
-        )
-
-    def test_state_filter_difference_simple_cases(self) -> None:
-        """
-        Tests some very simple cases of the StateFilter approx_difference,
-        that are not explicitly tested by the more in-depth tests.
-        """
-
-        self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
-
-        self.assert_difference(
-            StateFilter.all(),
-            StateFilter.none(),
-            StateFilter.all(),
-        )
-
-
-class StateFilterTestCase(TestCase):
-    def test_return_expanded(self) -> None:
-        """
-        Tests the behaviour of the return_expanded() function that expands
-        StateFilters to include more state types (for the sake of cache hit rate).
-        """
-
-        self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
-
-        self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
-
-        # Concrete-only state filters stay the same
-        # (Case: mixed filter)
-        self.assertEqual(
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:test", "@alicia:test"},
-                    "some.other.state.type": {""},
-                },
-                include_others=False,
-            ).return_expanded(),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:test", "@alicia:test"},
-                    "some.other.state.type": {""},
-                },
-                include_others=False,
-            ),
-        )
-
-        # Concrete-only state filters stay the same
-        # (Case: non-member-only filter)
-        self.assertEqual(
-            StateFilter.freeze(
-                {"some.other.state.type": {""}}, include_others=False
-            ).return_expanded(),
-            StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
-        )
-
-        # Concrete-only state filters stay the same
-        # (Case: member-only filter)
-        self.assertEqual(
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:test", "@alicia:test"},
-                },
-                include_others=False,
-            ).return_expanded(),
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:test", "@alicia:test"},
-                },
-                include_others=False,
-            ),
-        )
-
-        # Wildcard member-only state filters stay the same
-        self.assertEqual(
-            StateFilter.freeze(
-                {EventTypes.Member: None},
-                include_others=False,
-            ).return_expanded(),
-            StateFilter.freeze(
-                {EventTypes.Member: None},
-                include_others=False,
-            ),
-        )
-
-        # If there is a wildcard in the non-member portion of the filter,
-        # it's expanded to include ALL non-member events.
-        # (Case: mixed filter)
-        self.assertEqual(
-            StateFilter.freeze(
-                {
-                    EventTypes.Member: {"@wombat:test", "@alicia:test"},
-                    "some.other.state.type": None,
-                },
-                include_others=False,
-            ).return_expanded(),
-            StateFilter.freeze(
-                {EventTypes.Member: {"@wombat:test", "@alicia:test"}},
-                include_others=True,
-            ),
-        )
-
-        # If there is a wildcard in the non-member portion of the filter,
-        # it's expanded to include ALL non-member events.
-        # (Case: non-member-only filter)
-        self.assertEqual(
-            StateFilter.freeze(
-                {
-                    "some.other.state.type": None,
-                },
-                include_others=False,
-            ).return_expanded(),
-            StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
-        )
-        self.assertEqual(
-            StateFilter.freeze(
-                {
-                    "some.other.state.type": None,
-                    "yet.another.state.type": {"wombat"},
-                },
-                include_others=False,
-            ).return_expanded(),
-            StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
-        )
diff --git a/tests/types/__init__.py b/tests/types/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/types/__init__.py
diff --git a/tests/types/test_state.py b/tests/types/test_state.py
new file mode 100644
index 0000000000..eb809f9fb7
--- /dev/null
+++ b/tests/types/test_state.py
@@ -0,0 +1,627 @@
+from frozendict import frozendict
+
+from synapse.api.constants import EventTypes
+from synapse.types.state import StateFilter
+
+from tests.unittest import TestCase
+
+
+class StateFilterDifferenceTestCase(TestCase):
+    def assert_difference(
+        self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
+    ) -> None:
+        self.assertEqual(
+            minuend.approx_difference(subtrahend),
+            expected,
+            f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
+        )
+
+    def test_state_filter_difference_no_include_other_minus_no_include_other(
+        self,
+    ) -> None:
+        """
+        Tests the StateFilter.approx_difference method
+        where, in a.approx_difference(b), both a and b do not have the
+        include_others flag set.
+        """
+        # (wildcard on state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.Create: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+                include_others=False,
+            ),
+            StateFilter.freeze({EventTypes.Create: None}, include_others=False),
+        )
+
+        # (wildcard on state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        self.assert_difference(
+            StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+            StateFilter.freeze(
+                {EventTypes.Member: {"@wombat:spqr"}},
+                include_others=False,
+            ),
+            StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+        )
+
+        # (wildcard on state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.CanonicalAlias: {""}},
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (specific state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+        )
+
+    def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
+        """
+        Tests the StateFilter.approx_difference method
+        where, in a.approx_difference(b), only a has the include_others flag set.
+        """
+        # (wildcard on state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.Create: None},
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Create: None,
+                    EventTypes.Member: set(),
+                    EventTypes.CanonicalAlias: set(),
+                },
+                include_others=True,
+            ),
+        )
+
+        # (wildcard on state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        # This also shows that the resultant state filter is normalised.
+        self.assert_difference(
+            StateFilter.freeze({EventTypes.Member: None}, include_others=True),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                    EventTypes.Create: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter(types=frozendict(), include_others=True),
+        )
+
+        # (wildcard on state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=False,
+            ),
+            StateFilter(
+                types=frozendict(),
+                include_others=True,
+            ),
+        )
+
+        # (specific state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.CanonicalAlias: {""},
+                    EventTypes.Member: set(),
+                },
+                include_others=True,
+            ),
+        )
+
+        # (specific state keys) - (specific state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+        )
+
+        # (specific state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+        )
+
+    def test_state_filter_difference_include_other_minus_include_other(self) -> None:
+        """
+        Tests the StateFilter.approx_difference method
+        where, in a.approx_difference(b), both a and b have the include_others
+        flag set.
+        """
+        # (wildcard on state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.Create: None},
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+                include_others=True,
+            ),
+            StateFilter(types=frozendict(), include_others=False),
+        )
+
+        # (wildcard on state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        self.assert_difference(
+            StateFilter.freeze({EventTypes.Member: None}, include_others=True),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+                include_others=False,
+            ),
+        )
+
+        # (wildcard on state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=True,
+            ),
+            StateFilter(
+                types=frozendict(),
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                    EventTypes.Create: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                    EventTypes.Create: set(),
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@spqr:spqr"},
+                    EventTypes.Create: {""},
+                },
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                },
+                include_others=False,
+            ),
+        )
+
+    def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
+        """
+        Tests the StateFilter.approx_difference method
+        where, in a.approx_difference(b), only b has the include_others flag set.
+        """
+        # (wildcard on state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.Create: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+                include_others=True,
+            ),
+            StateFilter(types=frozendict(), include_others=False),
+        )
+
+        # (wildcard on state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        self.assert_difference(
+            StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+            StateFilter.freeze(
+                {EventTypes.Member: {"@wombat:spqr"}},
+                include_others=True,
+            ),
+            StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+        )
+
+        # (wildcard on state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=True,
+            ),
+            StateFilter(
+                types=frozendict(),
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@spqr:spqr"},
+                },
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                },
+                include_others=False,
+            ),
+        )
+
+    def test_state_filter_difference_simple_cases(self) -> None:
+        """
+        Tests some very simple cases of the StateFilter approx_difference,
+        that are not explicitly tested by the more in-depth tests.
+        """
+
+        self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
+
+        self.assert_difference(
+            StateFilter.all(),
+            StateFilter.none(),
+            StateFilter.all(),
+        )
+
+
+class StateFilterTestCase(TestCase):
+    def test_return_expanded(self) -> None:
+        """
+        Tests the behaviour of the return_expanded() function that expands
+        StateFilters to include more state types (for the sake of cache hit rate).
+        """
+
+        self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
+
+        self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
+
+        # Concrete-only state filters stay the same
+        # (Case: mixed filter)
+        self.assertEqual(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:test", "@alicia:test"},
+                    "some.other.state.type": {""},
+                },
+                include_others=False,
+            ).return_expanded(),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:test", "@alicia:test"},
+                    "some.other.state.type": {""},
+                },
+                include_others=False,
+            ),
+        )
+
+        # Concrete-only state filters stay the same
+        # (Case: non-member-only filter)
+        self.assertEqual(
+            StateFilter.freeze(
+                {"some.other.state.type": {""}}, include_others=False
+            ).return_expanded(),
+            StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
+        )
+
+        # Concrete-only state filters stay the same
+        # (Case: member-only filter)
+        self.assertEqual(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:test", "@alicia:test"},
+                },
+                include_others=False,
+            ).return_expanded(),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:test", "@alicia:test"},
+                },
+                include_others=False,
+            ),
+        )
+
+        # Wildcard member-only state filters stay the same
+        self.assertEqual(
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ).return_expanded(),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+        )
+
+        # If there is a wildcard in the non-member portion of the filter,
+        # it's expanded to include ALL non-member events.
+        # (Case: mixed filter)
+        self.assertEqual(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:test", "@alicia:test"},
+                    "some.other.state.type": None,
+                },
+                include_others=False,
+            ).return_expanded(),
+            StateFilter.freeze(
+                {EventTypes.Member: {"@wombat:test", "@alicia:test"}},
+                include_others=True,
+            ),
+        )
+
+        # If there is a wildcard in the non-member portion of the filter,
+        # it's expanded to include ALL non-member events.
+        # (Case: non-member-only filter)
+        self.assertEqual(
+            StateFilter.freeze(
+                {
+                    "some.other.state.type": None,
+                },
+                include_others=False,
+            ).return_expanded(),
+            StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
+        )
+        self.assertEqual(
+            StateFilter.freeze(
+                {
+                    "some.other.state.type": None,
+                    "yet.another.state.type": {"wombat"},
+                },
+                include_others=False,
+            ).return_expanded(),
+            StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
+        )