summary refs log tree commit diff
path: root/tests/storage/test_event_federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_event_federation.py')
-rw-r--r--tests/storage/test_event_federation.py249
1 files changed, 225 insertions, 24 deletions
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 482506d731..9d04a066d8 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,6 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import attr
+from parameterized import parameterized
+
+from synapse.events import _EventInternalMetadata
+
 import tests.unittest
 import tests.utils
 
@@ -113,7 +118,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
         self.assertTrue(r == [room2] or r == [room3])
 
-    def test_auth_difference(self):
+    @parameterized.expand([(True,), (False,)])
+    def test_auth_difference(self, use_chain_cover_index: bool):
         room_id = "@ROOM:local"
 
         # The silly auth graph we use to test the auth difference algorithm,
@@ -159,46 +165,223 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             "j": 1,
         }
 
+        # Mark the room as not having a cover index
+
+        def store_room(txn):
+            self.store.db_pool.simple_insert_txn(
+                txn,
+                "rooms",
+                {
+                    "room_id": room_id,
+                    "creator": "room_creator_user_id",
+                    "is_public": True,
+                    "room_version": "6",
+                    "has_auth_chain_index": use_chain_cover_index,
+                },
+            )
+
+        self.get_success(self.store.db_pool.runInteraction("store_room", store_room))
+
         # We rudely fiddle with the appropriate tables directly, as that's much
         # easier than constructing events properly.
 
-        def insert_event(txn, event_id, stream_ordering):
+        def insert_event(txn):
+            stream_ordering = 0
+
+            for event_id in auth_graph:
+                stream_ordering += 1
+                depth = depth_map[event_id]
+
+                self.store.db_pool.simple_insert_txn(
+                    txn,
+                    table="events",
+                    values={
+                        "event_id": event_id,
+                        "room_id": room_id,
+                        "depth": depth,
+                        "topological_ordering": depth,
+                        "type": "m.test",
+                        "processed": True,
+                        "outlier": False,
+                        "stream_ordering": stream_ordering,
+                    },
+                )
+
+            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+                txn,
+                [
+                    FakeEvent(event_id, room_id, auth_graph[event_id])
+                    for event_id in auth_graph
+                ],
+            )
+
+        self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+
+        # Now actually test that various combinations give the right result:
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "d", "e"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}])
+        )
+        self.assertSetEqual(difference, set())
+
+    def test_auth_difference_partial_cover(self):
+        """Test that we correctly handle rooms where not all events have a chain
+        cover calculated. This can happen in some obscure edge cases, including
+        during the background update that calculates the chain cover for old
+        rooms.
+        """
+
+        room_id = "@ROOM:local"
+
+        # The silly auth graph we use to test the auth difference algorithm,
+        # where the top are the most recent events.
+        #
+        #   A   B
+        #    \ /
+        #  D  E
+        #  \  |
+        #   ` F   C
+        #     |  /|
+        #     G ยด |
+        #     | \ |
+        #     H   I
+        #     |   |
+        #     K   J
+
+        auth_graph = {
+            "a": ["e"],
+            "b": ["e"],
+            "c": ["g", "i"],
+            "d": ["f"],
+            "e": ["f"],
+            "f": ["g"],
+            "g": ["h", "i"],
+            "h": ["k"],
+            "i": ["j"],
+            "k": [],
+            "j": [],
+        }
+
+        depth_map = {
+            "a": 7,
+            "b": 7,
+            "c": 4,
+            "d": 6,
+            "e": 6,
+            "f": 5,
+            "g": 3,
+            "h": 2,
+            "i": 2,
+            "k": 1,
+            "j": 1,
+        }
 
-            depth = depth_map[event_id]
+        # We rudely fiddle with the appropriate tables directly, as that's much
+        # easier than constructing events properly.
 
+        def insert_event(txn):
+            # First insert the room and mark it as having a chain cover.
             self.store.db_pool.simple_insert_txn(
                 txn,
-                table="events",
-                values={
-                    "event_id": event_id,
+                "rooms",
+                {
                     "room_id": room_id,
-                    "depth": depth,
-                    "topological_ordering": depth,
-                    "type": "m.test",
-                    "processed": True,
-                    "outlier": False,
-                    "stream_ordering": stream_ordering,
+                    "creator": "room_creator_user_id",
+                    "is_public": True,
+                    "room_version": "6",
+                    "has_auth_chain_index": True,
                 },
             )
 
-            self.store.db_pool.simple_insert_many_txn(
+            stream_ordering = 0
+
+            for event_id in auth_graph:
+                stream_ordering += 1
+                depth = depth_map[event_id]
+
+                self.store.db_pool.simple_insert_txn(
+                    txn,
+                    table="events",
+                    values={
+                        "event_id": event_id,
+                        "room_id": room_id,
+                        "depth": depth,
+                        "topological_ordering": depth,
+                        "type": "m.test",
+                        "processed": True,
+                        "outlier": False,
+                        "stream_ordering": stream_ordering,
+                    },
+                )
+
+            # Insert all events apart from 'B'
+            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
                 txn,
-                table="event_auth",
-                values=[
-                    {"event_id": event_id, "room_id": room_id, "auth_id": a}
-                    for a in auth_graph[event_id]
+                [
+                    FakeEvent(event_id, room_id, auth_graph[event_id])
+                    for event_id in auth_graph
+                    if event_id != "b"
                 ],
             )
 
-        next_stream_ordering = 0
-        for event_id in auth_graph:
-            next_stream_ordering += 1
-            self.get_success(
-                self.store.db_pool.runInteraction(
-                    "insert", insert_event, event_id, next_stream_ordering
-                )
+            # Now we insert the event 'B' without a chain cover, by temporarily
+            # pretending the room doesn't have a chain cover.
+
+            self.store.db_pool.simple_update_txn(
+                txn,
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                updatevalues={"has_auth_chain_index": False},
+            )
+
+            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+                txn, [FakeEvent("b", room_id, auth_graph["b"])],
+            )
+
+            self.store.db_pool.simple_update_txn(
+                txn,
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                updatevalues={"has_auth_chain_index": True},
             )
 
+        self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+
         # Now actually test that various combinations give the right result:
 
         difference = self.get_success(
@@ -240,3 +423,21 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             self.store.get_auth_chain_difference(room_id, [{"a"}])
         )
         self.assertSetEqual(difference, set())
+
+
+@attr.s
+class FakeEvent:
+    event_id = attr.ib()
+    room_id = attr.ib()
+    auth_events = attr.ib()
+
+    type = "foo"
+    state_key = "foo"
+
+    internal_metadata = _EventInternalMetadata({})
+
+    def auth_event_ids(self):
+        return self.auth_events
+
+    def is_state(self):
+        return True