diff options
author | David Robertson <davidr@element.io> | 2022-08-04 20:12:00 +0100 |
---|---|---|
committer | David Robertson <davidr@element.io> | 2022-08-04 20:12:00 +0100 |
commit | 7c9388f19ed52aaeba6ddc2ddf433f9b16e95b3c (patch) | |
tree | 4a3360979631bf4a85d9a4d28b78ca3022f98645 | |
parent | Handle events that we don't have the state for(?) (diff) | |
download | synapse-7c9388f19ed52aaeba6ddc2ddf433f9b16e95b3c.tar.xz |
Draw: readd --extras, arbitrary resolutions github/dmr/stateres/debug dmr/stateres/debug
-rwxr-xr-x | scripts-dev/debug_state_res.py | 135 |
1 files changed, 82 insertions, 53 deletions
diff --git a/scripts-dev/debug_state_res.py b/scripts-dev/debug_state_res.py index 02b8414809..f32c4d8468 100755 --- a/scripts-dev/debug_state_res.py +++ b/scripts-dev/debug_state_res.py @@ -3,7 +3,7 @@ import argparse import logging import sys from pprint import pformat -from typing import Awaitable, Callable, Collection, Optional, Tuple, cast +from typing import Awaitable, Callable, Collection, Dict, List, Optional, Tuple, cast from unittest.mock import MagicMock, patch import dictdiffer @@ -73,13 +73,15 @@ def node( event: EventBase, suffix: Optional[str] = None, **kwargs: object ) -> pydot.Node: if "label" not in kwargs: - label = f"{event.event_id}\n{event.sender}: {(event.type,event.state_key)}" + label = ( + f"{event.event_id}\n{event.sender}: {(event.type,event.get_state_key())}" + ) if event.type == "m.room.member": label += f" ({event.membership.upper()})" if suffix: label += f"\n{suffix}" kwargs["label"] = label - type_to_shape = {} # {"m.room.member": "oval"} + type_to_shape: Dict[str, str] = {} # {"m.room.member": "oval"} if event.type in type_to_shape: kwargs.setdefault("shape", type_to_shape[event.type]) @@ -97,9 +99,10 @@ def edge(source: EventBase, target: EventBase, **kwargs: object) -> pydot.Edge: async def dump_mainlines( hs: MockHomeserver, - starting_event: EventBase, + resolve_point: Optional[EventBase], + events: Collection[EventBase], + extras: Collection[str], watch_func: Optional[Callable[[EventBase], Awaitable[str]]] = None, - extras: Collection[EventBase] = (), ) -> None: """Visualise the auth DAG above a given `starting_event`. @@ -123,21 +126,29 @@ async def dump_mainlines( suffix = await watch_func(event) if watch_func else None return node(event, suffix, **kwargs) - graph.add_node(await new_node(starting_event, fillcolor="#6699cc")) - seen = {starting_event.event_id} + seen = set() + todo: List[EventBase] = [] - todo = [] - for extra in extras: - graph.add_node(await new_node(extra, fillcolor="#cc9966")) - seen.add(extra.event_id) - todo.append(extra) + if resolve_point: + graph.add_node(await new_node(resolve_point, fillcolor="#6699cc")) + seen.add(resolve_point.event_id) - for pid in starting_event.prev_event_ids(): - parent = await hs.get_datastores().main.get_event(pid) + for parent in events: graph.add_node(await new_node(parent, fillcolor="#6699cc")) - seen.add(pid) - graph.add_edge(edge(starting_event, parent, style="dashed")) + seen.add(parent.event_id) todo.append(parent) + if resolve_point: + graph.add_edge(edge(resolve_point, parent, style="dashed")) + + if extras: + logger.debug(extras) + extra_events = await hs.get_datastores().main.get_events(extras) + logger.debug(extra_events) + for extra_event in extra_events.values(): + if extra_event.event_id in seen: + continue + graph.add_node(await new_node(extra_event, fillcolor="#6699ee")) + todo.append(extra_event) async def fetch_auth_events(event: EventBase) -> StateMap[EventBase]: return { @@ -155,6 +166,8 @@ async def dump_mainlines( (("m.room.power_levels", ""), "solid"), (("m.room.join_rules", ""), "solid"), (("m.room.member", event.sender), "dotted"), + # TODO: handle that state_key might be missing + # (("m.room.member", event.state_key), "solid"), ]: auth_event = auth_events.get(key) if auth_event: @@ -189,13 +202,30 @@ parser.add_argument( "config_file", help="Synapse config file", type=argparse.FileType("r") ) parser.add_argument("--verbose", "-v", help="Log verbosely", action="store_true") +parser.add_argument("-d", "--draw", help="Render auth DAG", action="store_true") +parser.add_argument( + "event_ids", + help="""\ +The event ID(s) to be resolved.\ + +If a single event is given, resolve across all of its parents to compute the state +before the given event. If multiple events are given, resolve across them directly. +""", + nargs="+", +) parser.add_argument( - "--debug", "-d", help="Enter debugger after state is resolved", action="store_true" + "-e", + "--extra", + dest="extras", + help=( + "An extra event to include in the auth DAG when using the `--draw` flag. " + "Can be provided multiple times." + ), + action="append", ) -parser.add_argument("event_id", help="The event ID to be resolved") parser.add_argument( "--watch", - help="Track a piece of state in the auth DAG", + help="Track a piece of state in the auth DAG when using the `--draw` flag.", default=None, nargs=2, metavar=("TYPE", "STATE_KEY"), @@ -213,19 +243,22 @@ async def debug_specific_stateres( - the recomputed and stored state, written to stdout, and - their difference, written to stdout. """ - # Fetch the event in question. - event = await hs.get_datastores().main.get_event(args.event_id) - assert event is not None - logger.info( - "event %s has %d parents, %s", - event.event_id, - len(event.prev_event_ids()), - event.prev_event_ids(), - ) + DEBUG_AT_EVENT = len(args.event_ids) == 1 + if DEBUG_AT_EVENT: + resolve_point = await hs.get_datastores().main.get_event(args.event_ids[0]) + prev_event_ids = resolve_point.prev_event_ids() + else: + resolve_point = None + prev_event_ids = args.event_ids + + parent_events = (await hs.get_datastores().main.get_events(prev_event_ids)).values() + sample_event = next(iter(parent_events)) + + logger.info("Resolving across %d parents, %s", len(prev_event_ids), prev_event_ids) state_after_parents = [ await hs.get_storage_controllers().state.get_state_ids_for_event(prev_event_id) - for prev_event_id in event.prev_event_ids() + for prev_event_id in prev_event_ids ] if args.watch is not None: @@ -236,8 +269,10 @@ async def debug_specific_stateres( async def watch_func(event: EventBase) -> str: try: - result = await hs.get_storage_controllers().state.get_state_ids_for_event( - event.event_id, filter + result = ( + await hs.get_storage_controllers().state.get_state_ids_for_event( + event.event_id, filter + ) ) except RuntimeError: return f"\n{key_pair}: <Event unavailable :(>" @@ -247,37 +282,31 @@ async def debug_specific_stateres( else: watch_func = None - await dump_mainlines(hs, event, watch_func) + if args.draw: + await dump_mainlines(hs, resolve_point, parent_events, args.extras, watch_func) result = await hs.get_state_resolution_handler().resolve_events_with_store( - event.room_id, - event.room_version.identifier, + sample_event.room_id, + sample_event.room_version.identifier, state_after_parents, event_map=None, state_res_store=StateResolutionStore(hs.get_datastores().main), ) - logger.info("State resolved at %s:", event.event_id) + logger.info("State resolved:") logger.info(pformat(result)) - logger.info("Stored state at %s:", event.event_id) - stored_state = await hs.get_storage_controllers().state.get_state_ids_for_event( - event.event_id - ) - logger.info(pformat(stored_state)) - - # TODO make this a like-for-like comparison. - logger.info("Diff from stored (after event) to resolved (before event):") - for change in dictdiffer.diff(stored_state, result): - logger.info(pformat(change)) - - if args.debug: - print( - f"see `state_after_parents[i]` for 0 <= i < {len(state_after_parents)}" - " and `result`", - file=sys.stderr, + if DEBUG_AT_EVENT: + logger.info("Stored state at %s:", sample_event.event_id) + stored_state = await hs.get_storage_controllers().state.get_state_ids_for_event( + sample_event.event_id ) - breakpoint() + logger.info(pformat(stored_state)) + + # TODO make this a like-for-like comparison. + logger.info("Diff from stored (after event) to resolved (before event):") + for change in dictdiffer.diff(stored_state, result): + logger.info(pformat(change)) # Entrypoint. @@ -288,7 +317,7 @@ if __name__ == "__main__": level=logging.DEBUG if args.verbose else logging.INFO, stream=sys.stdout, ) - # Suppress logs weren't not interested in. + # Suppress logs we aren't interested in. logging.getLogger("synapse.util").setLevel(logging.ERROR) logging.getLogger("synapse.storage").setLevel(logging.ERROR) |