summary refs log tree commit diff
path: root/scripts-dev/debug_state_res.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts-dev/debug_state_res.py')
-rwxr-xr-xscripts-dev/debug_state_res.py135
1 files changed, 82 insertions, 53 deletions
diff --git a/scripts-dev/debug_state_res.py b/scripts-dev/debug_state_res.py
index 02b8414809..f32c4d8468 100755
--- a/scripts-dev/debug_state_res.py
+++ b/scripts-dev/debug_state_res.py
@@ -3,7 +3,7 @@ import argparse
 import logging
 import sys
 from pprint import pformat
-from typing import Awaitable, Callable, Collection, Optional, Tuple, cast
+from typing import Awaitable, Callable, Collection, Dict, List, Optional, Tuple, cast
 from unittest.mock import MagicMock, patch
 
 import dictdiffer
@@ -73,13 +73,15 @@ def node(
     event: EventBase, suffix: Optional[str] = None, **kwargs: object
 ) -> pydot.Node:
     if "label" not in kwargs:
-        label = f"{event.event_id}\n{event.sender}: {(event.type,event.state_key)}"
+        label = (
+            f"{event.event_id}\n{event.sender}: {(event.type,event.get_state_key())}"
+        )
         if event.type == "m.room.member":
             label += f" ({event.membership.upper()})"
         if suffix:
             label += f"\n{suffix}"
         kwargs["label"] = label
-    type_to_shape = {}  # {"m.room.member": "oval"}
+    type_to_shape: Dict[str, str] = {}  # {"m.room.member": "oval"}
     if event.type in type_to_shape:
         kwargs.setdefault("shape", type_to_shape[event.type])
 
@@ -97,9 +99,10 @@ def edge(source: EventBase, target: EventBase, **kwargs: object) -> pydot.Edge:
 
 async def dump_mainlines(
     hs: MockHomeserver,
-    starting_event: EventBase,
+    resolve_point: Optional[EventBase],
+    events: Collection[EventBase],
+    extras: Collection[str],
     watch_func: Optional[Callable[[EventBase], Awaitable[str]]] = None,
-    extras: Collection[EventBase] = (),
 ) -> None:
     """Visualise the auth DAG above a given `starting_event`.
 
@@ -123,21 +126,29 @@ async def dump_mainlines(
         suffix = await watch_func(event) if watch_func else None
         return node(event, suffix, **kwargs)
 
-    graph.add_node(await new_node(starting_event, fillcolor="#6699cc"))
-    seen = {starting_event.event_id}
+    seen = set()
+    todo: List[EventBase] = []
 
-    todo = []
-    for extra in extras:
-        graph.add_node(await new_node(extra, fillcolor="#cc9966"))
-        seen.add(extra.event_id)
-        todo.append(extra)
+    if resolve_point:
+        graph.add_node(await new_node(resolve_point, fillcolor="#6699cc"))
+        seen.add(resolve_point.event_id)
 
-    for pid in starting_event.prev_event_ids():
-        parent = await hs.get_datastores().main.get_event(pid)
+    for parent in events:
         graph.add_node(await new_node(parent, fillcolor="#6699cc"))
-        seen.add(pid)
-        graph.add_edge(edge(starting_event, parent, style="dashed"))
+        seen.add(parent.event_id)
         todo.append(parent)
+        if resolve_point:
+            graph.add_edge(edge(resolve_point, parent, style="dashed"))
+
+    if extras:
+        logger.debug(extras)
+        extra_events = await hs.get_datastores().main.get_events(extras)
+        logger.debug(extra_events)
+        for extra_event in extra_events.values():
+            if extra_event.event_id in seen:
+                continue
+            graph.add_node(await new_node(extra_event, fillcolor="#6699ee"))
+            todo.append(extra_event)
 
     async def fetch_auth_events(event: EventBase) -> StateMap[EventBase]:
         return {
@@ -155,6 +166,8 @@ async def dump_mainlines(
             (("m.room.power_levels", ""), "solid"),
             (("m.room.join_rules", ""), "solid"),
             (("m.room.member", event.sender), "dotted"),
+            # TODO: handle that state_key might be missing
+            # (("m.room.member", event.state_key), "solid"),
         ]:
             auth_event = auth_events.get(key)
             if auth_event:
@@ -189,13 +202,30 @@ parser.add_argument(
     "config_file", help="Synapse config file", type=argparse.FileType("r")
 )
 parser.add_argument("--verbose", "-v", help="Log verbosely", action="store_true")
+parser.add_argument("-d", "--draw", help="Render auth DAG", action="store_true")
+parser.add_argument(
+    "event_ids",
+    help="""\
+The event ID(s) to be resolved.\
+
+If a single event is given, resolve across all of its parents to compute the state
+before the given event. If multiple events are given, resolve across them directly.
+""",
+    nargs="+",
+)
 parser.add_argument(
-    "--debug", "-d", help="Enter debugger after state is resolved", action="store_true"
+    "-e",
+    "--extra",
+    dest="extras",
+    help=(
+        "An extra event to include in the auth DAG when using the `--draw` flag. "
+        "Can be provided multiple times."
+    ),
+    action="append",
 )
-parser.add_argument("event_id", help="The event ID to be resolved")
 parser.add_argument(
     "--watch",
-    help="Track a piece of state in the auth DAG",
+    help="Track a piece of state in the auth DAG when using the `--draw` flag.",
     default=None,
     nargs=2,
     metavar=("TYPE", "STATE_KEY"),
@@ -213,19 +243,22 @@ async def debug_specific_stateres(
     - the recomputed and stored state, written to stdout, and
     - their difference, written to stdout.
     """
-    # Fetch the event in question.
-    event = await hs.get_datastores().main.get_event(args.event_id)
-    assert event is not None
-    logger.info(
-        "event %s has %d parents, %s",
-        event.event_id,
-        len(event.prev_event_ids()),
-        event.prev_event_ids(),
-    )
+    DEBUG_AT_EVENT = len(args.event_ids) == 1
 
+    if DEBUG_AT_EVENT:
+        resolve_point = await hs.get_datastores().main.get_event(args.event_ids[0])
+        prev_event_ids = resolve_point.prev_event_ids()
+    else:
+        resolve_point = None
+        prev_event_ids = args.event_ids
+
+    parent_events = (await hs.get_datastores().main.get_events(prev_event_ids)).values()
+    sample_event = next(iter(parent_events))
+
+    logger.info("Resolving across %d parents, %s", len(prev_event_ids), prev_event_ids)
     state_after_parents = [
         await hs.get_storage_controllers().state.get_state_ids_for_event(prev_event_id)
-        for prev_event_id in event.prev_event_ids()
+        for prev_event_id in prev_event_ids
     ]
 
     if args.watch is not None:
@@ -236,8 +269,10 @@ async def debug_specific_stateres(
 
         async def watch_func(event: EventBase) -> str:
             try:
-                result = await hs.get_storage_controllers().state.get_state_ids_for_event(
-                    event.event_id, filter
+                result = (
+                    await hs.get_storage_controllers().state.get_state_ids_for_event(
+                        event.event_id, filter
+                    )
                 )
             except RuntimeError:
                 return f"\n{key_pair}: <Event unavailable :(>"
@@ -247,37 +282,31 @@ async def debug_specific_stateres(
     else:
         watch_func = None
 
-    await dump_mainlines(hs, event, watch_func)
+    if args.draw:
+        await dump_mainlines(hs, resolve_point, parent_events, args.extras, watch_func)
 
     result = await hs.get_state_resolution_handler().resolve_events_with_store(
-        event.room_id,
-        event.room_version.identifier,
+        sample_event.room_id,
+        sample_event.room_version.identifier,
         state_after_parents,
         event_map=None,
         state_res_store=StateResolutionStore(hs.get_datastores().main),
     )
 
-    logger.info("State resolved at %s:", event.event_id)
+    logger.info("State resolved:")
     logger.info(pformat(result))
 
-    logger.info("Stored state at %s:", event.event_id)
-    stored_state = await hs.get_storage_controllers().state.get_state_ids_for_event(
-        event.event_id
-    )
-    logger.info(pformat(stored_state))
-
-    # TODO make this a like-for-like comparison.
-    logger.info("Diff from stored (after event) to resolved (before event):")
-    for change in dictdiffer.diff(stored_state, result):
-        logger.info(pformat(change))
-
-    if args.debug:
-        print(
-            f"see `state_after_parents[i]` for 0 <= i < {len(state_after_parents)}"
-            " and `result`",
-            file=sys.stderr,
+    if DEBUG_AT_EVENT:
+        logger.info("Stored state at %s:", sample_event.event_id)
+        stored_state = await hs.get_storage_controllers().state.get_state_ids_for_event(
+            sample_event.event_id
         )
-        breakpoint()
+        logger.info(pformat(stored_state))
+
+        # TODO make this a like-for-like comparison.
+        logger.info("Diff from stored (after event) to resolved (before event):")
+        for change in dictdiffer.diff(stored_state, result):
+            logger.info(pformat(change))
 
 
 # Entrypoint.
@@ -288,7 +317,7 @@ if __name__ == "__main__":
         level=logging.DEBUG if args.verbose else logging.INFO,
         stream=sys.stdout,
     )
-    # Suppress logs weren't not interested in.
+    # Suppress logs we aren't interested in.
     logging.getLogger("synapse.util").setLevel(logging.ERROR)
     logging.getLogger("synapse.storage").setLevel(logging.ERROR)