summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2023-08-18 15:32:06 +0100
committerGitHub <noreply@github.com>2023-08-18 15:32:06 +0100
commitbd558a6dc369b6f5d06ab6fd2500faa216a45883 (patch)
tree8cd5827573d1537d23b62ed152d52b2840d6e854 /tests/storage
parentMSC3861: allow impersonation by an admin using a query param (#16132) (diff)
downloadsynapse-bd558a6dc369b6f5d06ab6fd2500faa216a45883.tar.xz
Speed up state res in rare case we don't have all events (#16116)
If we don't have all the auth events in a room then not all state events will have a chain cover index. Even so, we can still use the chain cover index on the events that do have it, rather than bailing and using the slower functions.

This situation should not arise for newly persisted rooms, as we check we have the full auth chain for each event, but can happen for existing rooms.

c.f. #15245
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test_event_federation.py241
1 files changed, 193 insertions, 48 deletions
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 9c151a5e62..7a4ecab2d5 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,7 +13,19 @@
 # limitations under the License.
 
 import datetime
-from typing import Dict, List, Tuple, Union, cast
+from typing import (
+    Collection,
+    Dict,
+    FrozenSet,
+    Iterable,
+    List,
+    Mapping,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import attr
 from parameterized import parameterized
@@ -38,6 +50,138 @@ from synapse.util import Clock, json_encoder
 import tests.unittest
 import tests.utils
 
+# The silly auth graph we use to test the auth difference algorithm,
+# where the top are the most recent events.
+#
+#   A   B
+#    \ /
+#  D  E
+#  \  |
+#   ` F   C
+#     |  /|
+#     G ´ |
+#     | \ |
+#     H   I
+#     |   |
+#     K   J
+
+AUTH_GRAPH: Dict[str, List[str]] = {
+    "a": ["e"],
+    "b": ["e"],
+    "c": ["g", "i"],
+    "d": ["f"],
+    "e": ["f"],
+    "f": ["g"],
+    "g": ["h", "i"],
+    "h": ["k"],
+    "i": ["j"],
+    "k": [],
+    "j": [],
+}
+
+DEPTH_GRAPH = {
+    "a": 7,
+    "b": 7,
+    "c": 4,
+    "d": 6,
+    "e": 6,
+    "f": 5,
+    "g": 3,
+    "h": 2,
+    "i": 2,
+    "k": 1,
+    "j": 1,
+}
+
+T = TypeVar("T")
+
+
+def get_all_topologically_sorted_orders(
+    nodes: Iterable[T],
+    graph: Mapping[T, Collection[T]],
+) -> List[List[T]]:
+    """Given a set of nodes and a graph, return all possible topological
+    orderings.
+    """
+
+    # This is implemented by Kahn's algorithm, and forking execution each time
+    # we have a choice over which node to consider next.
+
+    degree_map = {node: 0 for node in nodes}
+    reverse_graph: Dict[T, Set[T]] = {}
+
+    for node, edges in graph.items():
+        if node not in degree_map:
+            continue
+
+        for edge in set(edges):
+            if edge in degree_map:
+                degree_map[node] += 1
+
+            reverse_graph.setdefault(edge, set()).add(node)
+        reverse_graph.setdefault(node, set())
+
+    zero_degree = [node for node, degree in degree_map.items() if degree == 0]
+
+    return _get_all_topologically_sorted_orders_inner(
+        reverse_graph, zero_degree, degree_map
+    )
+
+
+def _get_all_topologically_sorted_orders_inner(
+    reverse_graph: Dict[T, Set[T]],
+    zero_degree: List[T],
+    degree_map: Dict[T, int],
+) -> List[List[T]]:
+    new_paths = []
+
+    # Rather than only choosing *one* item from the list of nodes with zero
+    # degree, we "fork" execution and run the algorithm for each node in the
+    # zero degree.
+    for node in zero_degree:
+        new_degree_map = degree_map.copy()
+        new_zero_degree = zero_degree.copy()
+        new_zero_degree.remove(node)
+
+        for edge in reverse_graph.get(node, []):
+            if edge in new_degree_map:
+                new_degree_map[edge] -= 1
+                if new_degree_map[edge] == 0:
+                    new_zero_degree.append(edge)
+
+        paths = _get_all_topologically_sorted_orders_inner(
+            reverse_graph, new_zero_degree, new_degree_map
+        )
+        for path in paths:
+            path.insert(0, node)
+
+        new_paths.extend(paths)
+
+    if not new_paths:
+        return [[]]
+
+    return new_paths
+
+
+def get_all_topologically_consistent_subsets(
+    nodes: Iterable[T],
+    graph: Mapping[T, Collection[T]],
+) -> Set[FrozenSet[T]]:
+    """Get all subsets of the graph where if node N is in the subgraph, then all
+    nodes that can reach that node (i.e. for all X there exists a path X -> N)
+    are in the subgraph.
+    """
+    all_topological_orderings = get_all_topologically_sorted_orders(nodes, graph)
+
+    graph_subsets = set()
+    for ordering in all_topological_orderings:
+        ordering.reverse()
+
+        for idx in range(len(ordering)):
+            graph_subsets.add(frozenset(ordering[:idx]))
+
+    return graph_subsets
+
 
 @attr.s(auto_attribs=True, frozen=True, slots=True)
 class _BackfillSetupInfo:
@@ -172,49 +316,6 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
     def _setup_auth_chain(self, use_chain_cover_index: bool) -> str:
         room_id = "@ROOM:local"
 
-        # The silly auth graph we use to test the auth difference algorithm,
-        # where the top are the most recent events.
-        #
-        #   A   B
-        #    \ /
-        #  D  E
-        #  \  |
-        #   ` F   C
-        #     |  /|
-        #     G ´ |
-        #     | \ |
-        #     H   I
-        #     |   |
-        #     K   J
-
-        auth_graph: Dict[str, List[str]] = {
-            "a": ["e"],
-            "b": ["e"],
-            "c": ["g", "i"],
-            "d": ["f"],
-            "e": ["f"],
-            "f": ["g"],
-            "g": ["h", "i"],
-            "h": ["k"],
-            "i": ["j"],
-            "k": [],
-            "j": [],
-        }
-
-        depth_map = {
-            "a": 7,
-            "b": 7,
-            "c": 4,
-            "d": 6,
-            "e": 6,
-            "f": 5,
-            "g": 3,
-            "h": 2,
-            "i": 2,
-            "k": 1,
-            "j": 1,
-        }
-
         # Mark the room as maybe having a cover index.
 
         def store_room(txn: LoggingTransaction) -> None:
@@ -238,9 +339,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         def insert_event(txn: LoggingTransaction) -> None:
             stream_ordering = 0
 
-            for event_id in auth_graph:
+            for event_id in AUTH_GRAPH:
                 stream_ordering += 1
-                depth = depth_map[event_id]
+                depth = DEPTH_GRAPH[event_id]
 
                 self.store.db_pool.simple_insert_txn(
                     txn,
@@ -260,8 +361,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             self.persist_events._persist_event_auth_chain_txn(
                 txn,
                 [
-                    cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
-                    for event_id in auth_graph
+                    cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
+                    for event_id in AUTH_GRAPH
                 ],
             )
 
@@ -344,7 +445,51 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         room_id = self._setup_auth_chain(use_chain_cover_index)
 
         # Now actually test that various combinations give the right result:
+        self.assert_auth_diff_is_expected(room_id)
+
+    @parameterized.expand(
+        [
+            [graph_subset]
+            for graph_subset in get_all_topologically_consistent_subsets(
+                AUTH_GRAPH, AUTH_GRAPH
+            )
+        ]
+    )
+    def test_auth_difference_partial(self, graph_subset: Collection[str]) -> None:
+        """Test that if we only have a chain cover index on a partial subset of
+        the room we still get the correct auth chain difference.
+
+        We do this by removing the chain cover index for every valid subset of the
+        graph.
+        """
+        room_id = self._setup_auth_chain(True)
+
+        for event_id in graph_subset:
+            # Remove chain cover from that event.
+            self.get_success(
+                self.store.db_pool.simple_delete(
+                    table="event_auth_chains",
+                    keyvalues={"event_id": event_id},
+                    desc="test_auth_difference_partial_remove",
+                )
+            )
+            self.get_success(
+                self.store.db_pool.simple_insert(
+                    table="event_auth_chain_to_calculate",
+                    values={
+                        "event_id": event_id,
+                        "room_id": room_id,
+                        "type": "",
+                        "state_key": "",
+                    },
+                    desc="test_auth_difference_partial_remove",
+                )
+            )
+
+        self.assert_auth_diff_is_expected(room_id)
 
+    def assert_auth_diff_is_expected(self, room_id: str) -> None:
+        """Assert the auth chain difference returns the correct answers."""
         difference = self.get_success(
             self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
         )