8000 preserve collectives by theodorbadea · Pull Request #190 · mlcommons/chakra · GitHub

preserve collectives #190


Open · wants to merge 8 commits into base: main
101 changes: 94 additions & 7 deletions src/trace_link/chakra_host_trace_loader.py
@@ -1,9 +1,11 @@
import logging
import sys
from typing import List
from typing import Any, Callable, Dict, List, Tuple

from et_replay.execution_trace import EXECUTION_TRACE_THREAD_ANNOTATION as THREAD_ANNOTATION
from et_replay.execution_trace import ExecutionTrace as PyTorchTrace
from et_replay.execution_trace import Node as PyTorchOperator
from et_replay.utils import load_execution_trace_file
from et_replay.utils import read_dictionary_from_json_file

# Increase the recursion limit for deep Chakra host execution traces.
sys.setrecursionlimit(10**6)
@@ -12,25 +14,31 @@
class ChakraHostTraceLoader:
"""Loads Chakra host traces."""

def load(self, chakra_host_trace_file: str) -> List[PyTorchOperator]:
def load(self,
chakra_host_trace_file: str,
connect_host_trace: bool) -> Tuple[List[PyTorchOperator], Dict[str, Any]]:
"""
Load and process the Chakra Host Execution Trace.

Args:
chakra_host_trace_file (str): Path to the PyTorch execution trace file.
connect_host_trace (bool): Connect host nodes with missing parents to the corresponding thread root node.

Returns:
List[PyTorchOperator]: List of PyTorch operators.
Tuple[List[PyTorchOperator], Dict[str, Any]]: Tuple containing list of PyTorch operators and host trace.
"""
logging.debug(f"Starting to load Chakra host execution trace from file: {chakra_host_trace_file}.")
chakra_host_trace = load_execution_trace_file(chakra_host_trace_file)
host_trace = read_dictionary_from_json_file(chakra_host_trace_file)

root_node = chakra_host_trace.get_nodes()[1] # Root node is usually 1-based
host_ops = self._create_host_ops(host_trace, connect_host_trace)
root_node = host_ops.get(1) # Root node is usually 1-based

chakra_host_ops = self.extract_chakra_host_ops(root_node)

logging.debug(f"Extracted {len(chakra_host_ops)} operators from Chakra host execution trace.")
logging.debug("Chakra host execution trace has been loaded and processed successfully.")

return chakra_host_ops
return chakra_host_ops, host_trace

def extract_chakra_host_ops(self, node: PyTorchOperator) -> List[PyTorchOperator]:
"""
@@ -55,3 +63,82 @@ def traverse(node: PyTorchOperator):
traverse(node)
logging.debug(f"Traversed {len(nodes)} nodes from root node ID: {node.id}")
return sorted(nodes, key=lambda x: x.id)

def _create_host_ops(self, host_trace: Dict[str, Any], connect_host_trace: bool) -> Dict[int, PyTorchOperator]:
"""
Create host operators from the provided host trace.

This method processes the host trace, extracts nodes, and creates PyTorchOperator instances based on the schema
version specified in the host trace.

Args:
host_trace (Dict[str, Any]): The host trace dictionary.
connect_host_trace (bool): Connect host nodes with missing parents to the corresponding thread root node.

Returns:
Dict[int, PyTorchOperator]: A dictionary mapping operator IDs to PyTorchOperator instances.
"""
schema: str = host_trace["schema"]
pid: int = host_trace["pid"]
nodes: List[Dict[str, Any]] = host_trace["nodes"]

create_operator = self._get_operator_creation_method(schema)
if create_operator is None:
raise ValueError(
f"No corresponding node creation function found for schema version {schema}"
)

host_ops: Dict[int, PyTorchOperator] = {}
thread_roots: Dict[int, int] = {}
for node in nodes:
host_op = create_operator(pid, node)
host_ops[host_op.id] = host_op
if host_op.parent_id == 1 and THREAD_ANNOTATION in host_op.name:
thread_roots[host_op.tid] = host_op.id

for host_op in host_ops.values():
if host_op.parent_id in host_ops and host_op.id != 1:
parent = host_ops[host_op.parent_id]
host_op.set_parent(parent)
parent.add_child(host_op)
elif connect_host_trace is True: # connect orphans to the thread root
parent_id = thread_roots.get(host_op.tid, None)
if parent_id is not None:
host_op.parent_id = parent_id
parent = host_ops[parent_id]
host_op.set_parent(parent)
parent.add_child(host_op)
node = next(filter(lambda n: n["id"] == host_op.id, nodes), None)
if node is not None:
node["ctrl_deps"] = parent_id

for host_op in host_ops.values():
host_op.sort_children()

return host_ops

def _get_operator_creation_method(self, schema: str) -> Callable[[int, Dict[str, Any]], PyTorchOperator] | None:
"""
Get the operator creation method for the specified schema version.

Args:
schema (str): The schema version of the host trace.

Returns:
Callable[[int, Dict[str, Any]], PyTorchOperator] | None: Operator creation functor for the schema version,
or None if no functor is found.
"""
node_creation_func = {
"1.0.1": PyTorchTrace._create_node_v1_0_1,
"1.0.2-chakra.0.0.4": PyTorchTrace._create_node_v1_0_2_chakra_0_0_4,
# 1.0.3 expands pg name to <pg_name, pg_desc> so it uses the same parser as 1.0.2
"1.0.3-chakra.0.0.4": PyTorchTrace._create_node_v1_0_2_chakra_0_0_4,
# 1.0.4 adds PT2 kernel backend and kernel file
"1.0.4-chakra.0.0.4": PyTorchTrace._create_node_v1_0_2_chakra_0_0_4,
# 1.1.0 includes new comm args in record_param_comms
"1.1.0-chakra.0.0.4": PyTorchTrace._create_node_v1_0_2_chakra_0_0_4,
# 1.1.1 includes tensor strides
"1.1.1-chakra.0.0.4": PyTorchTrace._create_node_v1_1_1_chakra_0_0_4,
# Add future versions here
}
return node_creation_func.get(schema)
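
For context on what connect_host_trace does, below is a minimal standalone sketch of the reconnection step in _create_host_ops, written against plain trace-node dicts instead of PyTorchOperator objects. The node values and the thread-annotation string are illustrative assumptions, not taken from a real trace.

# Sketch only: reattach nodes whose parent is missing from the trace to the
# root node of their thread, which is what enabling connect_host_trace achieves.
THREAD_ANNOTATION = "[pytorch|profiler|execution_trace|thread]"  # assumed annotation string

nodes = [
    {"id": 1, "name": "[pytorch|profiler|execution_trace|process]", "ctrl_deps": None, "tid": 1},
    {"id": 2, "name": "[pytorch|profiler|execution_trace|thread]", "ctrl_deps": 1, "tid": 7},
    {"id": 42, "name": "record_param_comms", "ctrl_deps": 999, "tid": 7},  # parent 999 is missing
]

thread_roots = {n["tid"]: n["id"] for n in nodes if n["ctrl_deps"] == 1 and THREAD_ANNOTATION in n["name"]}
known_ids = {n["id"] for n in nodes}
for n in nodes:
    if n["id"] != 1 and n["ctrl_deps"] not in known_ids:
        n["ctrl_deps"] = thread_roots.get(n["tid"], n["ctrl_deps"])  # orphan -> its thread root

assert nodes[2]["ctrl_deps"] == 2  # the orphaned op now hangs off its thread root

In the real loader the same idea is applied to PyTorchOperator objects and mirrored back into the raw node dicts via ctrl_deps, so downstream consumers of host_trace see the updated parents.
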
7 changes: 6 additions & 1 deletion src/trace_link/trace_link.py
@@ -37,14 +37,19 @@ def main() -> None:
required=True,
help="Path for the output Chakra host + device trace in the JSON format",
)
parser.add_argument("--connect-host-trace",
type=bool,
default=False,
help="Whether to connect host nodes with missing parents to the corresponding thread root node.",
)
parser.add_argument("--log-level", default="INFO", type=str, help="Log output verbosity level")

args = parser.parse_args()

logging.basicConfig(level=args.log_level.upper())

linker = TraceLinker()
linker.link(args.rank, args.chakra_host_trace, args.chakra_device_trace, args.output_file)
linker.link(args.rank, args.chakra_host_trace, args.chakra_device_trace, args.output_file, args.connect_host_trace)

logging.info(f"Linking process successful. Output file is available at {args.output_file}.")
logging.info("Please run the chakra_converter for further postprocessing.")
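
One behavioral note on the new flag, illustrated with a self-contained snippet (standard argparse semantics, not something specific to this PR): type=bool applies Python's bool() to the raw argument string, so any non-empty value, including the string "False", parses as True; only omitting the flag yields the default of False.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--connect-host-trace", type=bool, default=False)

print(parser.parse_args([]).connect_host_trace)                                   # False (default)
print(parser.parse_args(["--connect-host-trace", "True"]).connect_host_trace)     # True
print(parser.parse_args(["--connect-host-trace", "False"]).connect_host_trace)    # also True, not False

In practice this means the flag should be supplied (with any non-empty value) when the reconnection behavior is wanted, and omitted otherwise.
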
33 changes: 18 additions & 15 deletions src/trace_link/trace_linker.py
@@ -4,7 +4,7 @@
import logging
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

from et_replay.execution_trace import (
EXECUTION_TRACE_PROCESS_ANNOTATION,
@@ -36,7 +36,7 @@ def __init__(self) -> None:
self.chakra_device_trace_loader = ChakraDeviceTraceLoader()
self.id_assigner = UniqueIdAssigner()

def link(self, rank: int, chakra_host_trace: str, chakra_device_trace: str, output_file: str) -> None:
def link(self, rank: int,
chakra_host_trace: str,
chakra_device_trace: str,
output_file: str,
connect_host_trace: bool) -> None:
"""
Links Chakra host execution traces (ET) and Chakra device ET to generate Chakra host + device ET.

@@ -45,8 +49,9 @@ def link(self, rank: int, chakra_host_trace: str, chakra_device_trace: str, outp
chakra_host_trace (str): Path to the Chakra host execution trace file.
chakra_device_trace (str): Path to the Kineto trace file.
output_file (str): Path for the output PyTorch execution trace plus file.
connect_host_trace (bool): Connect host nodes with missing parents to the corresponding thread root node.
"""
host_ops = self.chakra_host_trace_loader.load(chakra_host_trace)
host_ops, host_trace = self.chakra_host_trace_loader.load(chakra_host_trace, connect_host_trace)

(
kineto_cpu_ops,
@@ -77,7 +82,7 @@ def link(self, rank: int, chakra_host_trace: str, chakra_device_trace: str, outp
)

chakra_execution_trace_plus_data = self.link_traces(
chakra_host_trace,
host_trace,
host_ops,
kineto_cpu_ops,
sorted_kineto_cpu_ops,
@@ -376,7 +381,7 @@ def find_closest_start_kineto_op(

def link_traces(
self,
chakra_host_trace: str,
host_trace: Dict[str, Any],
host_ops: List[PyTorchOperator],
kineto_cpu_ops: List[KinetoOperator],
sorted_kineto_cpu_ops: List[KinetoOperator],
@@ -393,7 +398,7 @@
Link Chakra Host ET and Chakra Device ET to produce an enhanced Chakra ET (ET +).

Args:
chakra_host_trace (str): Path to the Chakra host execution trace file.
host_trace (Dict[str, Any]): The Chakra host execution trace.
host_ops (List[PyTorchOperator]): List of Chakra host operators.
kineto_cpu_ops (List[KinetoOperator]): List of Kineto CPU operators.
sorted_kineto_cpu_ops (List[KinetoOperator]): Sorted list of Kineto CPU operators.
@@ -442,7 +447,7 @@ def link_traces(
kineto_external_id_to_kineto_op_map,
)
chakra_execution_trace_plus_data = self.construct_et_plus_data(
chakra_host_trace,
host_trace,
host_op_id_to_kineto_ops_map,
host_op_id_to_inclusive_dur_map,
host_op_id_to_exclusive_dur_map,
@@ -822,7 +827,7 @@ def link_gpu_ops(self, host_op: PyTorchOperator, kineto_gpu_ops: List[KinetoOper

def construct_et_plus_data(
self,
chakra_host_trace: str,
host_trace: Dict[str, Any],
host_op_id_to_kineto_ops_map: Dict[int, List[KinetoOperator]],
host_op_id_to_inclusive_dur_map: Dict[int, int],
host_op_id_to_exclusive_dur_map: Dict[int, int],
@@ -836,7 +841,7 @@ def construct_et_plus_data(
offering a comprehensive view of the execution.

Args:
chakra_host_trace (str): Path to the Chakra host execution trace file.
host_trace (Dict[str, Any]): The Chakra host execution trace.
host_op_id_to_kineto_ops_map (Dict[int, List[KinetoOperator]]): Map from Chakra host op IDs to Kineto
GPU ops.
host_op_id_to_inclusive_dur_map (Dict[int, int]): Inclusive duration map for Chakra host ops.
@@ -849,10 +854,8 @@ def construct_et_plus_data(
Dict: The constructed ET+ data.
"""
logging.debug("Constructing ET+ data.")
with open(chakra_host_trace, "r") as file:
pytorch_et_data = json.load(file)

sorted_nodes = sorted(pytorch_et_data["nodes"], key=lambda x: x["id"])
sorted_nodes = sorted(host_trace["nodes"], key=lambda x: x["id"])
gpu_ops = []
for op in sorted_nodes:
gpu_ops += self.process_op_and_dependents(
@@ -863,7 +866,7 @@
host_op_id_to_timestamp_map,
host_op_id_to_inter_thread_dep_map,
)
pytorch_et_data["nodes"] += gpu_ops
host_trace["nodes"] += gpu_ops

# Add sync dependencies
sync_dep_mapping = {}
@@ -876,15 +879,15 @@
del gpu_op["sync_dep_to"]

# Update parent-child relationships with new IDs
sorted_nodes = sorted(pytorch_et_data["nodes"], key=lambda x: x["id"])
sorted_nodes = sorted(host_trace["nodes"], key=lambda x: x["id"])
for op in sorted_nodes:
for key in sync_dep_mapping:
if self.id_assigner.lookup_new_id(key) == op["id"]:
op["sync_dep"] = sync_dep_mapping[key]
if "ctrl_deps" in op:
op["ctrl_deps"] = self.id_assigner.assign_or_retrieve_id(op["ctrl_deps"])

return pytorch_et_data
return host_trace

def process_op_and_dependents(
self,
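
For reference, a hypothetical end-to-end call matching the updated signature of TraceLinker.link; the import path and the file names are assumptions made for illustration only.

from chakra.src.trace_link.trace_linker import TraceLinker  # import path assumed

linker = TraceLinker()
linker.link(
    rank=0,
    chakra_host_trace="host_trace_rank0.json",       # illustrative input/output paths
    chakra_device_trace="kineto_trace_rank0.json",
    output_file="et_plus_rank0.json",
    connect_host_trace=True,  # reattach orphaned host ops (e.g. collectives) to their thread roots
)

Since load() now returns the parsed host-trace dict alongside the operators, construct_et_plus_data consumes that dict directly instead of re-reading the JSON file, so the ctrl_deps updates made during reconnection carry through to the final ET+ output.
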
2 changes: 1 addition & 1 deletion tests/trace_link/test_trace_linker.py
@@ -594,7 +594,7 @@ def test_construct_et_plus_data(mock_json_load, mock_open, mock_process_op_and_d
host_op_id_to_inter_thread_dep_map = {1: None, 2: None}

pytorch_et_plus_data = trace_linker.construct_et_plus_data(
"path/to/pytorch_et_file",
mock_json_load.return_value,
host_op_id_to_kineto_ops_map,
host_op_id_to_inclusive_dur_map,
host_op_id_to_exclusive_dur_map,