diff --git a/scripts-dev/debug_state_res.py b/scripts-dev/debug_state_res.py
index 02b8414809..f32c4d8468 100755
--- a/scripts-dev/debug_state_res.py
+++ b/scripts-dev/debug_state_res.py
@@ -3,7 +3,7 @@ import argparse
import logging
import sys
from pprint import pformat
-from typing import Awaitable, Callable, Collection, Optional, Tuple, cast
+from typing import Awaitable, Callable, Collection, Dict, List, Optional, Tuple, cast
from unittest.mock import MagicMock, patch
import dictdiffer
@@ -73,13 +73,15 @@ def node(
event: EventBase, suffix: Optional[str] = None, **kwargs: object
) -> pydot.Node:
if "label" not in kwargs:
- label = f"{event.event_id}\n{event.sender}: {(event.type,event.state_key)}"
+ label = (
+ f"{event.event_id}\n{event.sender}: {(event.type,event.get_state_key())}"
+ )
if event.type == "m.room.member":
label += f" ({event.membership.upper()})"
if suffix:
label += f"\n{suffix}"
kwargs["label"] = label
- type_to_shape = {} # {"m.room.member": "oval"}
+ type_to_shape: Dict[str, str] = {} # {"m.room.member": "oval"}
if event.type in type_to_shape:
kwargs.setdefault("shape", type_to_shape[event.type])
@@ -97,9 +99,10 @@ def edge(source: EventBase, target: EventBase, **kwargs: object) -> pydot.Edge:
async def dump_mainlines(
hs: MockHomeserver,
- starting_event: EventBase,
+ resolve_point: Optional[EventBase],
+ events: Collection[EventBase],
+ extras: Collection[str],
watch_func: Optional[Callable[[EventBase], Awaitable[str]]] = None,
- extras: Collection[EventBase] = (),
) -> None:
"""Visualise the auth DAG above a given `starting_event`.
@@ -123,21 +126,29 @@ async def dump_mainlines(
suffix = await watch_func(event) if watch_func else None
return node(event, suffix, **kwargs)
- graph.add_node(await new_node(starting_event, fillcolor="#6699cc"))
- seen = {starting_event.event_id}
+ seen = set()
+ todo: List[EventBase] = []
- todo = []
- for extra in extras:
- graph.add_node(await new_node(extra, fillcolor="#cc9966"))
- seen.add(extra.event_id)
- todo.append(extra)
+ if resolve_point:
+ graph.add_node(await new_node(resolve_point, fillcolor="#6699cc"))
+ seen.add(resolve_point.event_id)
- for pid in starting_event.prev_event_ids():
- parent = await hs.get_datastores().main.get_event(pid)
+ for parent in events:
graph.add_node(await new_node(parent, fillcolor="#6699cc"))
- seen.add(pid)
- graph.add_edge(edge(starting_event, parent, style="dashed"))
+ seen.add(parent.event_id)
todo.append(parent)
+ if resolve_point:
+ graph.add_edge(edge(resolve_point, parent, style="dashed"))
+
+ if extras:
+ logger.debug(extras)
+ extra_events = await hs.get_datastores().main.get_events(extras)
+ logger.debug(extra_events)
+ for extra_event in extra_events.values():
+ if extra_event.event_id in seen:
+ continue
+ graph.add_node(await new_node(extra_event, fillcolor="#6699ee"))
+ todo.append(extra_event)
async def fetch_auth_events(event: EventBase) -> StateMap[EventBase]:
return {
@@ -155,6 +166,8 @@ async def dump_mainlines(
(("m.room.power_levels", ""), "solid"),
(("m.room.join_rules", ""), "solid"),
(("m.room.member", event.sender), "dotted"),
+ # TODO: handle that state_key might be missing
+ # (("m.room.member", event.state_key), "solid"),
]:
auth_event = auth_events.get(key)
if auth_event:
@@ -189,13 +202,30 @@ parser.add_argument(
"config_file", help="Synapse config file", type=argparse.FileType("r")
)
parser.add_argument("--verbose", "-v", help="Log verbosely", action="store_true")
+parser.add_argument("-d", "--draw", help="Render auth DAG", action="store_true")
+parser.add_argument(
+ "event_ids",
+ help="""\
+The event ID(s) to be resolved.\
+
+If a single event is given, resolve across all of its parents to compute the state
+before the given event. If multiple events are given, resolve across them directly.
+""",
+ nargs="+",
+)
parser.add_argument(
- "--debug", "-d", help="Enter debugger after state is resolved", action="store_true"
+ "-e",
+ "--extra",
+ dest="extras",
+ help=(
+ "An extra event to include in the auth DAG when using the `--draw` flag. "
+ "Can be provided multiple times."
+ ),
+ action="append",
)
-parser.add_argument("event_id", help="The event ID to be resolved")
parser.add_argument(
"--watch",
- help="Track a piece of state in the auth DAG",
+ help="Track a piece of state in the auth DAG when using the `--draw` flag.",
default=None,
nargs=2,
metavar=("TYPE", "STATE_KEY"),
@@ -213,19 +243,22 @@ async def debug_specific_stateres(
- the recomputed and stored state, written to stdout, and
- their difference, written to stdout.
"""
- # Fetch the event in question.
- event = await hs.get_datastores().main.get_event(args.event_id)
- assert event is not None
- logger.info(
- "event %s has %d parents, %s",
- event.event_id,
- len(event.prev_event_ids()),
- event.prev_event_ids(),
- )
+ DEBUG_AT_EVENT = len(args.event_ids) == 1
+ if DEBUG_AT_EVENT:
+ resolve_point = await hs.get_datastores().main.get_event(args.event_ids[0])
+ prev_event_ids = resolve_point.prev_event_ids()
+ else:
+ resolve_point = None
+ prev_event_ids = args.event_ids
+
+ parent_events = (await hs.get_datastores().main.get_events(prev_event_ids)).values()
+ sample_event = next(iter(parent_events))
+
+ logger.info("Resolving across %d parents, %s", len(prev_event_ids), prev_event_ids)
state_after_parents = [
await hs.get_storage_controllers().state.get_state_ids_for_event(prev_event_id)
- for prev_event_id in event.prev_event_ids()
+ for prev_event_id in prev_event_ids
]
if args.watch is not None:
@@ -236,8 +269,10 @@ async def debug_specific_stateres(
async def watch_func(event: EventBase) -> str:
try:
- result = await hs.get_storage_controllers().state.get_state_ids_for_event(
- event.event_id, filter
+ result = (
+ await hs.get_storage_controllers().state.get_state_ids_for_event(
+ event.event_id, filter
+ )
)
except RuntimeError:
return f"\n{key_pair}: <Event unavailable :(>"
@@ -247,37 +282,31 @@ async def debug_specific_stateres(
else:
watch_func = None
- await dump_mainlines(hs, event, watch_func)
+ if args.draw:
+ await dump_mainlines(hs, resolve_point, parent_events, args.extras, watch_func)
result = await hs.get_state_resolution_handler().resolve_events_with_store(
- event.room_id,
- event.room_version.identifier,
+ sample_event.room_id,
+ sample_event.room_version.identifier,
state_after_parents,
event_map=None,
state_res_store=StateResolutionStore(hs.get_datastores().main),
)
- logger.info("State resolved at %s:", event.event_id)
+ logger.info("State resolved:")
logger.info(pformat(result))
- logger.info("Stored state at %s:", event.event_id)
- stored_state = await hs.get_storage_controllers().state.get_state_ids_for_event(
- event.event_id
- )
- logger.info(pformat(stored_state))
-
- # TODO make this a like-for-like comparison.
- logger.info("Diff from stored (after event) to resolved (before event):")
- for change in dictdiffer.diff(stored_state, result):
- logger.info(pformat(change))
-
- if args.debug:
- print(
- f"see `state_after_parents[i]` for 0 <= i < {len(state_after_parents)}"
- " and `result`",
- file=sys.stderr,
+ if DEBUG_AT_EVENT:
+ logger.info("Stored state at %s:", sample_event.event_id)
+ stored_state = await hs.get_storage_controllers().state.get_state_ids_for_event(
+ sample_event.event_id
)
- breakpoint()
+ logger.info(pformat(stored_state))
+
+ # TODO make this a like-for-like comparison.
+ logger.info("Diff from stored (after event) to resolved (before event):")
+ for change in dictdiffer.diff(stored_state, result):
+ logger.info(pformat(change))
# Entrypoint.
@@ -288,7 +317,7 @@ if __name__ == "__main__":
level=logging.DEBUG if args.verbose else logging.INFO,
stream=sys.stdout,
)
- # Suppress logs weren't not interested in.
+ # Suppress logs we aren't interested in.
logging.getLogger("synapse.util").setLevel(logging.ERROR)
logging.getLogger("synapse.storage").setLevel(logging.ERROR)
|