From a46db390580d375cedbe443ccd503859dd621d72 Mon Sep 17 00:00:00 2001
From: theodorbadea
Date: Thu, 27 Mar 2025 18:47:34 +0000
Subject: [PATCH 1/8] preserve collectives

---
 src/converter/pytorch_converter.py         |  3 ++-
 src/trace_link/chakra_host_trace_loader.py | 10 ++++++++--
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/src/converter/pytorch_converter.py b/src/converter/pytorch_converter.py
index ea383a51..5eabe96a 100644
--- a/src/converter/pytorch_converter.py
+++ b/src/converter/pytorch_converter.py
@@ -49,7 +49,8 @@ def convert(self, input_filename: str, output_filename: str, simulate: bool) ->
         for root_node in root_node_list:
             self.convert_ctrl_dep_to_data_dep(json_node_map, protobuf_node_map, root_node)
 
-        protobuf_node_map = self.remove_dangling_nodes(protobuf_node_map)
+        # do not remove secondary connected components
+        # protobuf_node_map = self.remove_dangling_nodes(protobuf_node_map)
 
         parent_to_children_map = self.update_parent_to_children_map(protobuf_node_map)
 
diff --git a/src/trace_link/chakra_host_trace_loader.py b/src/trace_link/chakra_host_trace_loader.py
index 8b2723b3..4d5898eb 100644
--- a/src/trace_link/chakra_host_trace_loader.py
+++ b/src/trace_link/chakra_host_trace_loader.py
@@ -25,8 +25,14 @@ def load(self, chakra_host_trace_file: str) -> List[PyTorchOperator]:
         logging.debug(f"Starting to load Chakra host execution trace from file: {chakra_host_trace_file}.")
         chakra_host_trace = load_execution_trace_file(chakra_host_trace_file)
 
-        root_node = chakra_host_trace.get_nodes()[1]  # Root node is usually 1-based
-        chakra_host_ops = self.extract_chakra_host_ops(root_node)
+        # root_node = chakra_host_trace.get_nodes()[1]  # Root node is usually 1-based
+        # chakra_host_ops = self.extract_chakra_host_ops(root_node)
+
+        # also include orphaned nodes belonging to secondary connected components
+        chakra_host_ops = []
+        for host_op in chakra_host_trace.get_nodes().values():
+            chakra_host_ops.append(host_op)
+
         logging.debug(f"Extracted {len(chakra_host_ops)} operators from Chakra host execution trace.")
 
         logging.debug("Chakra host execution trace has been loaded and processed successfully.")

From a99a5b3e30b2a6e6825461938e34b15e583df95c Mon Sep 17 00:00:00 2001
From: Theodor Badea
Date: Wed, 28 May 2025 17:52:44 +0000
Subject: [PATCH 2/8] link orphans to closest root by thread id

---
 src/trace_link/chakra_host_trace_loader.py | 92 +++++++++++++---------
 src/trace_link/trace_linker.py             | 26 +++---
 2 files changed, 68 insertions(+), 50 deletions(-)

diff --git a/src/trace_link/chakra_host_trace_loader.py b/src/trace_link/chakra_host_trace_loader.py
index 4d5898eb..003fc53f 100644
--- a/src/trace_link/chakra_host_trace_loader.py
+++ b/src/trace_link/chakra_host_trace_loader.py
@@ -1,9 +1,11 @@
 import logging
 import sys
-from typing import List
+from typing import Any, Callable, Dict, List, Tuple
 
 from et_replay.execution_trace import Node as PyTorchOperator
-from et_replay.utils import load_execution_trace_file
+from et_replay.execution_trace import ExecutionTrace as PyTorchTrace
+from et_replay.execution_trace import EXECUTION_TRACE_THREAD_ANNOTATION
+from et_replay.utils import read_dictionary_from_json_file
 
 # Increase the recursion limit for deep Chakra host execution traces.
 sys.setrecursionlimit(10**6)
@@ -12,7 +14,7 @@ class ChakraHostTraceLoader:
     """Loads Chakra host traces."""
 
-    def load(self, chakra_host_trace_file: str) -> List[PyTorchOperator]:
+    def load(self, chakra_host_trace_file: str) -> Tuple[List[PyTorchOperator], Dict[str, Any]]:
         """
         Load and process the Chakra Host Execution Trace.
 
@@ -20,44 +22,62 @@ def load(self, chakra_host_trace_file: str) -> List[PyTorchOperator]:
             chakra_host_trace_file (str): Path to the PyTorch execution trace file.
 
         Returns:
-            List[PyTorchOperator]: List of PyTorch operators.
+            Tuple[List[PyTorchOperator], Dict[str, Any]]: A tuple containing a list of PyTorch operators and a host trace.
         """
         logging.debug(f"Starting to load Chakra host execution trace from file: {chakra_host_trace_file}.")
-        chakra_host_trace = load_execution_trace_file(chakra_host_trace_file)
+        host_trace = read_dictionary_from_json_file(chakra_host_trace_file)
 
-        # root_node = chakra_host_trace.get_nodes()[1]  # Root node is usually 1-based
-        # chakra_host_ops = self.extract_chakra_host_ops(root_node)
+        schema: str = host_trace["schema"]
+        pid: int = host_trace["pid"]
+        nodes: List[Dict[str, Any]] = host_trace["nodes"]
+
+        create_operator = self._get_operator_creation_method(schema)
+        if create_operator is None:
+            raise ValueError(
+                f"No corresponding node creation function found for schema version {schema}"
+            )
+
+        host_ops: Dict[int, PyTorchOperator] = {}
+        thread_roots: Dict[int, int] = {}
+        for node in nodes:
+            op = create_operator(pid, node)
+            if op.parent_id == 1 and EXECUTION_TRACE_THREAD_ANNOTATION in op.name:
+                thread_roots[op.tid] = op.id
+            host_ops[op.id] = op
+
+        for op in host_ops.values():
+            if op.parent_id not in host_ops:
+                parent_id = thread_roots.get(op.tid, None)
+                if parent_id is not None:
+                    op.parent_id = parent_id
+                    op.set_parent(host_ops[parent_id])
+                    host_ops[parent_id].add_child(op)
+                    node = next(filter(lambda n: n["id"] == op.id, nodes), None)
+                    if node is not None:
+                        node["ctrl_deps"] = parent_id
 
-        # also include orphaned nodes belonging to secondary connected components
-        chakra_host_ops = []
-        for host_op in chakra_host_trace.get_nodes().values():
-            chakra_host_ops.append(host_op)
+        for op in host_ops.values():
+            op.sort_children()
+
+        chakra_host_ops = sorted(host_ops.values(), key=lambda x: x.id)
 
         logging.debug(f"Extracted {len(chakra_host_ops)} operators from Chakra host execution trace.")
 
         logging.debug("Chakra host execution trace has been loaded and processed successfully.")
-        return chakra_host_ops
-
-    def extract_chakra_host_ops(self, node: PyTorchOperator) -> List[PyTorchOperator]:
-        """
-        Extract and sort nodes from the PyTorch execution trace recursively.
-
-        This method traverses the execution trace starting from the provided node, extracting all the operator nodes
-        recursively, and then returns them sorted by their identifiers.
-
-        Args:
-            node (PyTorchOperator): Starting node for extraction.
-
-        Returns:
-            List[PyTorchOperator]: Sorted list of extracted PyTorchOperator nodes.
-        """
-        nodes = []
-
-        def traverse(node: PyTorchOperator):
-            nodes.append(node)
-            for child in node.children:
-                traverse(child)
-
-        traverse(node)
-        logging.debug(f"Traversed {len(nodes)} nodes from root node ID: {node.id}")
-        return sorted(nodes, key=lambda x: x.id)
+        return chakra_host_ops, host_trace
+
+    def _get_operator_creation_method(self, schema: str) -> Callable[[int, Dict[str, Any]], PyTorchOperator] | None:
+        node_creation_func = {
+            "1.0.1": PyTorchTrace._create_node_v1_0_1,
+            "1.0.2-chakra.0.0.4": PyTorchTrace._create_node_v1_0_2_chakra_0_0_4,
+            # 1.0.3 expands pg name, so it uses the same parser as 1.0.2
+            "1.0.3-chakra.0.0.4": PyTorchTrace._create_node_v1_0_2_chakra_0_0_4,
+            # 1.0.4 adds PT2 kernel backend and kernel file
+            "1.0.4-chakra.0.0.4": PyTorchTrace._create_node_v1_0_2_chakra_0_0_4,
+            # 1.1.0 includes new comm args in record_param_comms
+            "1.1.0-chakra.0.0.4": PyTorchTrace._create_node_v1_0_2_chakra_0_0_4,
+            # 1.1.1 includes tensor strides
+            "1.1.1-chakra.0.0.4": PyTorchTrace._create_node_v1_1_1_chakra_0_0_4,
+            # Add future versions here
+        }
+        return node_creation_func.get(schema, None)
diff --git a/src/trace_link/trace_linker.py b/src/trace_link/trace_linker.py
index 1123e45d..c45322cb 100644
--- a/src/trace_link/trace_linker.py
+++ b/src/trace_link/trace_linker.py
@@ -4,7 +4,7 @@
 import logging
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from et_replay.execution_trace import (
     EXECUTION_TRACE_PROCESS_ANNOTATION,
@@ -46,7 +46,7 @@ def link(self, rank: int, chakra_host_trace: str, chakra_device_trace: str, outp
             chakra_device_trace (str): Path to the Kineto trace file.
             output_file (str): Path for the output PyTorch execution trace plus file.
         """
-        host_ops = self.chakra_host_trace_loader.load(chakra_host_trace)
+        host_ops, host_trace = self.chakra_host_trace_loader.load(chakra_host_trace)
 
         (
             kineto_cpu_ops,
@@ -77,7 +77,7 @@ def link(self, rank: int, chakra_host_trace: str, chakra_device_trace: str, outp
         )
 
         chakra_execution_trace_plus_data = self.link_traces(
-            chakra_host_trace,
+            host_trace,
             host_ops,
             kineto_cpu_ops,
             sorted_kineto_cpu_ops,
@@ -376,7 +376,7 @@ def find_closest_start_kineto_op(
 
     def link_traces(
         self,
-        chakra_host_trace: str,
+        host_trace: Dict[str, Any],
         host_ops: List[PyTorchOperator],
         kineto_cpu_ops: List[KinetoOperator],
         sorted_kineto_cpu_ops: List[KinetoOperator],
@@ -393,7 +393,7 @@ def link_traces(
         Link Chakra Host ET and Chakra Device ET to produce an enhanced Chakra ET (ET+).
 
         Args:
-            chakra_host_trace (str): Path to the Chakra host execution trace file.
+            host_trace (Dict[str, Any]): The Chakra host execution trace.
             host_ops (List[PyTorchOperator]): List of Chakra host operators.
             kineto_cpu_ops (List[KinetoOperator]): List of Kineto CPU operators.
             sorted_kineto_cpu_ops (List[KinetoOperator]): Sorted list of Kineto CPU operators.
@@ -442,7 +442,7 @@ def link_traces(
             kineto_external_id_to_kineto_op_map,
         )
         chakra_execution_trace_plus_data = self.construct_et_plus_data(
-            chakra_host_trace,
+            host_trace,
             host_op_id_to_kineto_ops_map,
             host_op_id_to_inclusive_dur_map,
             host_op_id_to_exclusive_dur_map,
@@ -822,7 +822,7 @@ def link_gpu_ops(self, host_op: PyTorchOperator, kineto_gpu_ops: List[KinetoOper
 
     def construct_et_plus_data(
         self,
-        chakra_host_trace: str,
+        host_trace: Dict[str, Any],
         host_op_id_to_kineto_ops_map: Dict[int, List[KinetoOperator]],
         host_op_id_to_inclusive_dur_map: Dict[int, int],
         host_op_id_to_exclusive_dur_map: Dict[int, int],
@@ -836,7 +836,7 @@ def construct_et_plus_data(
         offering a comprehensive view of the execution.
 
         Args:
-            chakra_host_trace (str): Path to the Chakra host execution trace file.
+            host_trace (Dict[str, Any]): The Chakra host execution trace.
            host_op_id_to_kineto_ops_map (Dict[int, List[KinetoOperator]]): Map from Chakra host op IDs to Kineto GPU
                 ops.
             host_op_id_to_inclusive_dur_map (Dict[int, int]): Inclusive duration map for Chakra host ops.
@@ -849,10 +849,8 @@ def construct_et_plus_data(
             Dict: The constructed ET+ data.
         """
         logging.debug("Constructing ET+ data.")
-        with open(chakra_host_trace, "r") as file:
-            pytorch_et_data = json.load(file)
 
-        sorted_nodes = sorted(pytorch_et_data["nodes"], key=lambda x: x["id"])
+        sorted_nodes = sorted(host_trace["nodes"], key=lambda x: x["id"])
         gpu_ops = []
         for op in sorted_nodes:
             gpu_ops += self.process_op_and_dependents(
@@ -863,7 +861,7 @@ def construct_et_plus_data(
                 host_op_id_to_timestamp_map,
                 host_op_id_to_inter_thread_dep_map,
             )
-        pytorch_et_data["nodes"] += gpu_ops
+        host_trace["nodes"] += gpu_ops
 
         # Add sync dependencies
         sync_dep_mapping = {}
@@ -876,7 +874,7 @@ def construct_et_plus_data(
             del gpu_op["sync_dep_to"]
 
         # Update parent-child relationships with new IDs
-        sorted_nodes = sorted(pytorch_et_data["nodes"], key=lambda x: x["id"])
+        sorted_nodes = sorted(host_trace["nodes"], key=lambda x: x["id"])
         for op in sorted_nodes:
             for key in sync_dep_mapping:
                 if self.id_assigner.lookup_new_id(key) == op["id"]:
@@ -884,7 +882,7 @@ def construct_et_plus_data(
             if "ctrl_deps" in op:
                 op["ctrl_deps"] = self.id_assigner.assign_or_retrieve_id(op["ctrl_deps"])
 
-        return pytorch_et_data
+        return host_trace
 
     def process_op_and_dependents(
         self,

From 2621cf9b5d41e115d7ad4551cb343d8d0e862a54 Mon Sep 17 00:00:00 2001
From: Theodor Badea
Date: Wed, 28 May 2025 18:00:50 +0000
Subject: [PATCH 3/8] linter

---
 src/trace_link/chakra_host_trace_loader.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/trace_link/chakra_host_trace_loader.py b/src/trace_link/chakra_host_trace_loader.py
index 003fc53f..8d1a09e5 100644
--- a/src/trace_link/chakra_host_trace_loader.py
+++ b/src/trace_link/chakra_host_trace_loader.py
@@ -2,9 +2,9 @@
 import sys
 from typing import Any, Callable, Dict, List, Tuple
 
-from et_replay.execution_trace import Node as PyTorchOperator
-from et_replay.execution_trace import ExecutionTrace as PyTorchTrace
 from et_replay.execution_trace import EXECUTION_TRACE_THREAD_ANNOTATION
+from et_replay.execution_trace import ExecutionTrace as PyTorchTrace
+from et_replay.execution_trace import Node as PyTorchOperator
 from et_replay.utils import read_dictionary_from_json_file
 
 # Increase the recursion limit for deep Chakra host execution traces.
@@ -22,7 +22,7 @@ def load(self, chakra_host_trace_file: str) -> Tuple[List[PyTorchOperator], Dict
             chakra_host_trace_file (str): Path to the PyTorch execution trace file.
 
         Returns:
-            Tuple[List[PyTorchOperator], Dict[str, Any]]: A tuple containing a list of PyTorch operators and a host trace.
+            Tuple[List[PyTorchOperator], Dict[str, Any]]: Tuple containing list of PyTorch operators and host trace.
         """
         logging.debug(f"Starting to load Chakra host execution trace from file: {chakra_host_trace_file}.")
         host_trace = read_dictionary_from_json_file(chakra_host_trace_file)
@@ -80,4 +80,4 @@ def _get_operator_creation_method(self, schema: str) -> Callable[[int, Dict[str,
             "1.1.1-chakra.0.0.4": PyTorchTrace._create_node_v1_1_1_chakra_0_0_4,
             # Add future versions here
         }
-        return node_creation_func.get(schema, None)
+        return node_creation_func.get(schema)

From e788fe6292394ba9ce2d926de1563bf5115ec0cf Mon Sep 17 00:00:00 2001
From: Theodor Badea
Date: Thu, 29 May 2025 12:05:25 +0000
Subject: [PATCH 4/8] add under flag

---
 src/trace_link/chakra_host_trace_loader.py | 105 ++++++++++++++-----
 src/trace_link/trace_link.py               |   7 +-
 src/trace_link/trace_linker.py             |   9 +-
 3 files changed, 94 insertions(+), 27 deletions(-)

diff --git a/src/trace_link/chakra_host_trace_loader.py b/src/trace_link/chakra_host_trace_loader.py
index 8d1a09e5..6458191e 100644
--- a/src/trace_link/chakra_host_trace_loader.py
+++ b/src/trace_link/chakra_host_trace_loader.py
@@ -2,7 +2,7 @@
 import sys
 from typing import Any, Callable, Dict, List, Tuple
 
-from et_replay.execution_trace import EXECUTION_TRACE_THREAD_ANNOTATION
+from et_replay.execution_trace import EXECUTION_TRACE_THREAD_ANNOTATION as THREAD_ANNOTATION
 from et_replay.execution_trace import ExecutionTrace as PyTorchTrace
 from et_replay.execution_trace import Node as PyTorchOperator
 from et_replay.utils import read_dictionary_from_json_file
@@ -14,19 +14,67 @@ class ChakraHostTraceLoader:
     """Loads Chakra host traces."""
 
-    def load(self, chakra_host_trace_file: str) -> Tuple[List[PyTorchOperator], Dict[str, Any]]:
+    def load(self,
+             chakra_host_trace_file: str,
+             connect_host_trace: bool) -> Tuple[List[PyTorchOperator], Dict[str, Any]]:
         """
         Load and process the Chakra Host Execution Trace.
 
         Args:
             chakra_host_trace_file (str): Path to the PyTorch execution trace file.
-
+            connect_host_trace (bool): Connect host nodes with missing parents to the corresponding thread root node.
         Returns:
             Tuple[List[PyTorchOperator], Dict[str, Any]]: Tuple containing list of PyTorch operators and host trace.
         """
         logging.debug(f"Starting to load Chakra host execution trace from file: {chakra_host_trace_file}.")
         host_trace = read_dictionary_from_json_file(chakra_host_trace_file)
 
+        host_ops = self._create_host_ops(host_trace, connect_host_trace)
+        root_node = host_ops.get(1)  # Root node is usually 1-based
+
+        chakra_host_ops = self.extract_chakra_host_ops(root_node)
+
+        logging.debug(f"Extracted {len(chakra_host_ops)} operators from Chakra host execution trace.")
+        logging.debug("Chakra host execution trace has been loaded and processed successfully.")
+
+        return chakra_host_ops, host_trace
+
+    def extract_chakra_host_ops(self, node: PyTorchOperator) -> List[PyTorchOperator]:
+        """
+        Extract and sort nodes from the PyTorch execution trace recursively.
+
+        This method traverses the execution trace starting from the provided node, extracting all the operator nodes
+        recursively, and then returns them sorted by their identifiers.
+
+        Args:
+            node (PyTorchOperator): Starting node for extraction.
+        Returns:
+            List[PyTorchOperator]: Sorted list of extracted PyTorchOperator nodes.
+ """ + nodes = [] + + def traverse(node: PyTorchOperator): + nodes.append(node) + for child in node.children: + traverse(child) + + traverse(node) + logging.debug(f"Traversed {len(nodes)} nodes from root node ID: {node.id}") + return sorted(nodes, key=lambda x: x.id) + + def _create_host_ops(self, host_trace: Dict[str, Any], connect_host_trace: bool) -> Dict[int, PyTorchOperator]: + """ + Create host operators from the provided host trace. + This method processes the host trace, extracts nodes, and creates PyTorchOperator instances based on the schema + version specified in the host trace. + + Args: + host_trace (Dict[str, Any]): The host trace dictionary. + connect_host_trace (bool): Connect host nodes with missing parents to the corresponding thread root node. + Returns: + Dict[int, PyTorchOperator]: A dictionary mapping operator IDs to PyTorchOperator instances. + """ + schema: str = host_trace["schema"] pid: int = host_trace["pid"] nodes: List[Dict[str, Any]] = host_trace["nodes"] @@ -40,33 +88,42 @@ def load(self, chakra_host_trace_file: str) -> Tuple[List[PyTorchOperator], Dict host_ops: Dict[int, PyTorchOperator] = {} thread_roots: Dict[int, int] = {} for node in nodes: - op = create_operator(pid, node) - if op.parent_id == 1 and EXECUTION_TRACE_THREAD_ANNOTATION in op.name: - thread_roots[op.tid] = op.id - host_ops[op.id] = op - - for op in host_ops.values(): - if op.parent_id not in host_ops: - parent_id = thread_roots.get(op.tid, None) + host_op = create_operator(pid, node) + host_ops[host_op.id] = host_op + if host_op.parent_id == 1 and THREAD_ANNOTATION in host_op.name: + thread_roots[host_op.tid] = host_op.id + + for host_op in host_ops.values(): + if host_op.parent_id in host_ops and host_op.id != 1: + parent = host_ops[host_op.parent_id] + host_op.set_parent(parent) + parent.add_child(host_op) + elif connect_host_trace is True: # connect orphans to the thread root + parent_id = thread_roots.get(host_op.tid, None) if parent_id is not None: - op.parent_id = parent_id - op.set_parent(host_ops[parent_id]) - host_ops[parent_id].add_child(op) - node = next(filter(lambda n: n["id"] == op.id, nodes), None) + host_op.parent_id = parent_id + parent = host_ops[parent_id] + host_op.set_parent(parent) + parent.add_child(host_op) + node = next(filter(lambda n: n["id"] == host_op.id, nodes), None) if node is not None: node["ctrl_deps"] = parent_id - for op in host_ops.values(): - op.sort_children() - - chakra_host_ops = sorted(host_ops.values(), key=lambda x: x.id) - - logging.debug(f"Extracted {len(chakra_host_ops)} operators from Chakra host execution trace.") - logging.debug("Chakra host execution trace has been loaded and processed successfully.") - - return chakra_host_ops, host_trace + for host_op in host_ops.values(): + host_op.sort_children() + return host_ops + def _get_operator_creation_method(self, schema: str) -> Callable[[int, Dict[str, Any]], PyTorchOperator] | None: + """ + Get the operator creation method for the specified schema version. + + Args: + schema (str): The schema version of the host trace. + Returns: + Callable[[int, Dict[str, Any]], PyTorchOperator] | None: The operator creation functor for the schema version, + or None if no functor is found. 
+ """ node_creation_func = { "1.0.1": PyTorchTrace._create_node_v1_0_1, "1.0.2-chakra.0.0.4": PyTorchTrace._create_node_v1_0_2_chakra_0_0_4, diff --git a/src/trace_link/trace_link.py b/src/trace_link/trace_link.py index 12074df5..99299e5e 100644 --- a/src/trace_link/trace_link.py +++ b/src/trace_link/trace_link.py @@ -37,6 +37,11 @@ def main() -> None: required=True, help="Path for the output Chakra host + device trace in the JSON format", ) + parser.add_argument("--connect-host-trace", + type=bool, + default=False, + help="Whether to connect host nodes with missing parents to the corresponding thread root node.", + ) parser.add_argument("--log-level", default="INFO", type=str, help="Log output verbosity level") args = parser.parse_args() @@ -44,7 +49,7 @@ def main() -> None: logging.basicConfig(level=args.log_level.upper()) linker = TraceLinker() - linker.link(args.rank, args.chakra_host_trace, args.chakra_device_trace, args.output_file) + linker.link(args.rank, args.chakra_host_trace, args.chakra_device_trace, args.output_file, args.connect_host_trace) logging.info(f"Linking process successful. Output file is available at {args.output_file}.") logging.info("Please run the chakra_converter for further postprocessing.") diff --git a/src/trace_link/trace_linker.py b/src/trace_link/trace_linker.py index c45322cb..6f957610 100644 --- a/src/trace_link/trace_linker.py +++ b/src/trace_link/trace_linker.py @@ -36,7 +36,11 @@ def __init__(self) -> None: self.chakra_device_trace_loader = ChakraDeviceTraceLoader() self.id_assigner = UniqueIdAssigner() - def link(self, rank: int, chakra_host_trace: str, chakra_device_trace: str, output_file: str) -> None: + def link(self, rank: int, + chakra_host_trace: str, + chakra_device_trace: str, + output_file: str, + connect_host_trace: bool) -> None: """ Links Chakra host execution traces (ET) and Chakra device ET to generate Chakra host + device ET. @@ -45,8 +49,9 @@ def link(self, rank: int, chakra_host_trace: str, chakra_device_trace: str, outp chakra_host_trace (str): Path to the Chakra host execution trace file. chakra_device_trace (str): Path to the Kineto trace file. output_file (str): Path for the output nyTorch execution trace plus file. + connect_host_trace (bool): Connect host nodes with missing parents to the corresponding thread root node. 
""" - host_ops, host_trace = self.chakra_host_trace_loader.load(chakra_host_trace) + host_ops, host_trace = self.chakra_host_trace_loader.load(chakra_host_trace, connect_host_trace) ( kineto_cpu_ops, From 15ac783c5647c3960d8733c4002b9ad43bd7f1ff Mon Sep 17 00:00:00 2001 From: Theodor Badea Date: Thu, 29 May 2025 12:09:08 +0000 Subject: [PATCH 5/8] revert changes to converter --- src/converter/pytorch_converter.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/converter/pytorch_converter.py b/src/converter/pytorch_converter.py index 5eabe96a..ea383a51 100644 --- a/src/converter/pytorch_converter.py +++ b/src/converter/pytorch_converter.py @@ -49,8 +49,7 @@ def convert(self, input_filename: str, output_filename: str, simulate: bool) -> for root_node in root_node_list: self.convert_ctrl_dep_to_data_dep(json_node_map, protobuf_node_map, root_node) - # do not remove secondary connected components - # protobuf_node_map = self.remove_dangling_nodes(protobuf_node_map) + protobuf_node_map = self.remove_dangling_nodes(protobuf_node_map) parent_to_children_map = self.update_parent_to_children_map(protobuf_node_map) From 4edf3bacb5888228ce6d96d6b5e4460d38a45b57 Mon Sep 17 00:00:00 2001 From: Theodor Badea Date: Thu, 29 May 2025 12:17:07 +0000 Subject: [PATCH 6/8] fix lint --- src/trace_link/chakra_host_trace_loader.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/trace_link/chakra_host_trace_loader.py b/src/trace_link/chakra_host_trace_loader.py index 6458191e..e1266c30 100644 --- a/src/trace_link/chakra_host_trace_loader.py +++ b/src/trace_link/chakra_host_trace_loader.py @@ -23,6 +23,7 @@ def load(self, Args: chakra_host_trace_file (str): Path to the PyTorch execution trace file. connect_host_trace (bool): Connect host nodes with missing parents to the corresponding thread root node. + Returns: Tuple[List[PyTorchOperator], Dict[str, Any]]: Tuple containing list of PyTorch operators and host trace. """ @@ -48,6 +49,7 @@ def extract_chakra_host_ops(self, node: PyTorchOperator) -> List[PyTorchOperator Args: node (PyTorchOperator): Starting node for extraction. + Returns: List[PyTorchOperator]: Sorted list of extracted PyTorchOperator nodes. """ @@ -65,12 +67,14 @@ def traverse(node: PyTorchOperator): def _create_host_ops(self, host_trace: Dict[str, Any], connect_host_trace: bool) -> Dict[int, PyTorchOperator]: """ Create host operators from the provided host trace. + This method processes the host trace, extracts nodes, and creates PyTorchOperator instances based on the schema version specified in the host trace. Args: host_trace (Dict[str, Any]): The host trace dictionary. connect_host_trace (bool): Connect host nodes with missing parents to the corresponding thread root node. + Returns: Dict[int, PyTorchOperator]: A dictionary mapping operator IDs to PyTorchOperator instances. """ @@ -119,9 +123,10 @@ def _get_operator_creation_method(self, schema: str) -> Callable[[int, Dict[str, Get the operator creation method for the specified schema version. Args: - schema (str): The schema version of the host trace. + schema (str): The schema version of the host trace. + Returns: - Callable[[int, Dict[str, Any]], PyTorchOperator] | None: The operator creation functor for the schema version, + Callable[[int, Dict[str, Any]], PyTorchOperator] | None: Operator creation functor for the schema version, or None if no functor is found. 
""" node_creation_func = { From be26b925ed774926a250fb41f2c242859c147987 Mon Sep 17 00:00:00 2001 From: Theodor Badea Date: Thu, 29 May 2025 12:18:47 +0000 Subject: [PATCH 7/8] lint --- src/trace_link/chakra_host_trace_loader.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/trace_link/chakra_host_trace_loader.py b/src/trace_link/chakra_host_trace_loader.py index e1266c30..ef173f94 100644 --- a/src/trace_link/chakra_host_trace_loader.py +++ b/src/trace_link/chakra_host_trace_loader.py @@ -77,8 +77,7 @@ def _create_host_ops(self, host_trace: Dict[str, Any], connect_host_trace: bool) Returns: Dict[int, PyTorchOperator]: A dictionary mapping operator IDs to PyTorchOperator instances. - """ - + """ schema: str = host_trace["schema"] pid: int = host_trace["pid"] nodes: List[Dict[str, Any]] = host_trace["nodes"] From a85226b42d65173ddb66fb6a7ba6ba3b112fc30e Mon Sep 17 00:00:00 2001 From: Theodor Badea Date: Thu, 29 May 2025 13:01:03 +0000 Subject: [PATCH 8/8] test refactor --- tests/trace_link/test_trace_linker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trace_link/test_trace_linker.py b/tests/trace_link/test_trace_linker.py index a0441ae4..ca8ccc9c 100644 --- a/tests/trace_link/test_trace_linker.py +++ b/tests/trace_link/test_trace_linker.py @@ -594,7 +594,7 @@ def test_construct_et_plus_data(mock_json_load, mock_open, mock_process_op_and_d host_op_id_to_inter_thread_dep_map = {1: None, 2: None} pytorch_et_plus_data = trace_linker.construct_et_plus_data( - "path/to/pytorch_et_file", + mock_json_load.return_value, host_op_id_to_kineto_ops_map, host_op_id_to_inclusive_dur_map, host_op_id_to_exclusive_dur_map,