summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/storage/test_state.py513
1 files changed, 511 insertions, 2 deletions
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 32060f2abd..70d52b088c 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -21,7 +21,7 @@ from synapse.api.room_versions import RoomVersions
 from synapse.storage.state import StateFilter
 from synapse.types import RoomID, UserID
 
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, TestCase
 
 logger = logging.getLogger(__name__)
 
@@ -105,7 +105,6 @@ class StateStoreTestCase(HomeserverTestCase):
         self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
 
     def test_get_state_for_event(self):
-
         # this defaults to a linear DAG as each new injection defaults to whatever
         # forward extremities are currently in the DB for this room.
         e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
@@ -483,3 +482,513 @@ class StateStoreTestCase(HomeserverTestCase):
 
         self.assertEqual(is_all, True)
         self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
+
+
+class StateFilterDifferenceTestCase(TestCase):
+    def assert_difference(
+        self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
+    ):
+        self.assertEqual(
+            minuend.approx_difference(subtrahend),
+            expected,
+            f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
+        )
+
+    def test_state_filter_difference_no_include_other_minus_no_include_other(self):
+        """
+        Tests the StateFilter.approx_difference method
+        where, in a.approx_difference(b), both a and b do not have the
+        include_others flag set.
+        """
+        # (wildcard on state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.Create: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+                include_others=False,
+            ),
+            StateFilter.freeze({EventTypes.Create: None}, include_others=False),
+        )
+
+        # (wildcard on state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        self.assert_difference(
+            StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+            StateFilter.freeze(
+                {EventTypes.Member: {"@wombat:spqr"}},
+                include_others=False,
+            ),
+            StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+        )
+
+        # (wildcard on state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.CanonicalAlias: {""}},
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (specific state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+        )
+
+    def test_state_filter_difference_include_other_minus_no_include_other(self):
+        """
+        Tests the StateFilter.approx_difference method
+        where, in a.approx_difference(b), only a has the include_others flag set.
+        """
+        # (wildcard on state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.Create: None},
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Create: None,
+                    EventTypes.Member: set(),
+                    EventTypes.CanonicalAlias: set(),
+                },
+                include_others=True,
+            ),
+        )
+
+        # (wildcard on state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        # This also shows that the resultant state filter is normalised.
+        self.assert_difference(
+            StateFilter.freeze({EventTypes.Member: None}, include_others=True),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                    EventTypes.Create: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter(types=frozendict(), include_others=True),
+        )
+
+        # (wildcard on state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=False,
+            ),
+            StateFilter(
+                types=frozendict(),
+                include_others=True,
+            ),
+        )
+
+        # (specific state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.CanonicalAlias: {""},
+                    EventTypes.Member: set(),
+                },
+                include_others=True,
+            ),
+        )
+
+        # (specific state keys) - (specific state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+        )
+
+        # (specific state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+        )
+
+    def test_state_filter_difference_include_other_minus_include_other(self):
+        """
+        Tests the StateFilter.approx_difference method
+        where, in a.approx_difference(b), both a and b have the include_others
+        flag set.
+        """
+        # (wildcard on state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.Create: None},
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+                include_others=True,
+            ),
+            StateFilter(types=frozendict(), include_others=False),
+        )
+
+        # (wildcard on state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        self.assert_difference(
+            StateFilter.freeze({EventTypes.Member: None}, include_others=True),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+                include_others=False,
+            ),
+        )
+
+        # (wildcard on state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=True,
+            ),
+            StateFilter(
+                types=frozendict(),
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                    EventTypes.Create: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                    EventTypes.Create: set(),
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@spqr:spqr"},
+                    EventTypes.Create: {""},
+                },
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                },
+                include_others=False,
+            ),
+        )
+
+    def test_state_filter_difference_no_include_other_minus_include_other(self):
+        """
+        Tests the StateFilter.approx_difference method
+        where, in a.approx_difference(b), only b has the include_others flag set.
+        """
+        # (wildcard on state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.Create: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+                include_others=True,
+            ),
+            StateFilter(types=frozendict(), include_others=False),
+        )
+
+        # (wildcard on state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        self.assert_difference(
+            StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+            StateFilter.freeze(
+                {EventTypes.Member: {"@wombat:spqr"}},
+                include_others=True,
+            ),
+            StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+        )
+
+        # (wildcard on state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=True,
+            ),
+            StateFilter(
+                types=frozendict(),
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@spqr:spqr"},
+                },
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                },
+                include_others=False,
+            ),
+        )
+
+    def test_state_filter_difference_simple_cases(self):
+        """
+        Tests some very simple cases of the StateFilter approx_difference,
+        that are not explicitly tested by the more in-depth tests.
+        """
+
+        self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
+
+        self.assert_difference(
+            StateFilter.all(),
+            StateFilter.none(),
+            StateFilter.all(),
+        )