From aacfcff9f6016f534639e5f9ee7cc12a95592918 Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Wed, 10 Jul 2024 14:23:09 -0400 Subject: [PATCH 1/5] Leverage HTA to support synchronization dependency in trace linking Co-authored-by: Joongun Park <8554137+JoongunPark@users.noreply.github.com> --- .github/workflows/python_tests.yml | 9 + src/converter/pytorch_converter.py | 12 +- src/converter/pytorch_node.py | 2 +- src/trace_link/chakra_device_trace_loader.py | 13 ++ src/trace_link/kineto_operator.py | 7 +- src/trace_link/trace_link.py | 4 +- src/trace_link/trace_linker.py | 221 ++++++++++++++++++- tests/trace_link/test_kineto_operator.py | 2 +- tests/trace_link/test_trace_linker.py | 195 ++++++++++++++++ 9 files changed, 456 insertions(+), 9 deletions(-) diff --git a/.github/workflows/python_tests.yml b/.github/workflows/python_tests.yml index 250cd7e0..a44a2e77 100644 --- a/.github/workflows/python_tests.yml +++ b/.github/workflows/python_tests.yml @@ -26,6 +26,15 @@ jobs: git checkout 7b19f586dd8b267333114992833a0d7e0d601630 pip install . + - name: Install HTA + run: | + git clone https://github.com/facebookresearch/HolisticTraceAnalysis.git + cd HolisticTraceAnalysis + git checkout d731cc2e2249976c97129d409a83bd53d93051f6 + git submodule update --init + pip install -r requirements.txt + pip install -e . + - name: Install Dependencies run: | pip install -r requirements-dev.txt diff --git a/src/converter/pytorch_converter.py b/src/converter/pytorch_converter.py index 2371dccb..9474c556 100644 --- a/src/converter/pytorch_converter.py +++ b/src/converter/pytorch_converter.py @@ -244,7 +244,11 @@ def convert_json_to_protobuf_nodes( [ ChakraAttr(name="comm_type", int64_val=collective_comm_type), ChakraAttr(name="comm_size", int64_val=pytorch_gpu_node.comm_size), - *( [ChakraAttr(name="pg_name", string_val=pytorch_gpu_node.pg_name)] if pytorch_gpu_node.pg_name != "" else [] ), + *( + [ChakraAttr(name="pg_name", string_val=pytorch_gpu_node.pg_name)] + if pytorch_gpu_node.pg_name != "" + else [] + ), ] ) @@ -252,7 +256,11 @@ def convert_json_to_protobuf_nodes( chakra_gpu_node.attr.extend( [ ChakraAttr(name="comm_size", int64_val=pytorch_gpu_node.comm_size), - *( [ChakraAttr(name="pg_name", string_val=pytorch_gpu_node.pg_name)] if pytorch_gpu_node.pg_name != "" else [] ), + *( + [ChakraAttr(name="pg_name", string_val=pytorch_gpu_node.pg_name)] + if pytorch_gpu_node.pg_name != "" + else [] + ), ] ) diff --git a/src/converter/pytorch_node.py b/src/converter/pytorch_node.py index b569bd51..50feb4ac 100644 --- a/src/converter/pytorch_node.py +++ b/src/converter/pytorch_node.py @@ -19,7 +19,7 @@ class PyTorchNodeType(Enum): CPU_OP = 1 GPU_OP = 2 LABEL = 3 # Non-operator nodes - METADATA = 4 # Metadata nodes + METADATA = 4 # Metadata nodes class PyTorchNode: diff --git a/src/trace_link/chakra_device_trace_loader.py b/src/trace_link/chakra_device_trace_loader.py index 70111c59..0737f132 100644 --- a/src/trace_link/chakra_device_trace_loader.py +++ b/src/trace_link/chakra_device_trace_loader.py @@ -16,6 +16,7 @@ def load( ) -> Tuple[ List[KinetoOperator], Dict[int, List[KinetoOperator]], + Dict[int, List[KinetoOperator]], Dict[int, KinetoOperator], List[KinetoOperator], Dict[int, KinetoOperator], @@ -26,6 +27,7 @@ def load( Dict[int, KinetoOperator], List[KinetoOperator], List[int], + Dict[int, KinetoOperator], ]: """ Load and process the Chakra device trace. 
@@ -57,6 +59,7 @@ def load( logging.debug("Chakra device trace has been loaded and processed successfully.") return ( dev_data["kineto_cpu_ops"], + dev_data["kineto_tid_ops_map"], dev_data["kineto_tid_cpu_ops_map"], dev_data["kineto_correlation_cuda_runtime_map"], dev_data["kineto_gpu_ops"], @@ -68,6 +71,7 @@ def load( dev_data["kineto_rf_id_to_kineto_op_map"], dev_data["sorted_kineto_cpu_ops"], dev_data["sorted_kineto_cpu_op_ts"], + dev_data["kineto_external_id_to_kineto_op_map"], ) def construct_dev_data_structures(self, kineto_ops: List[KinetoOperator], trace_file: str) -> Dict: @@ -90,13 +94,17 @@ def construct_dev_data_structures(self, kineto_ops: List[KinetoOperator], trace_ thread_info = {} kineto_cpu_ops = [] + kineto_tid_ops_map = {} kineto_tid_cpu_ops_map = {} kineto_correlation_cuda_runtime_map = {} kineto_gpu_ops = [] kineto_id_arrow_op_map = {} kineto_id_cuda_launch_op_map = {} + kineto_external_id_to_kineto_op_map = {} for op in kineto_ops: + kineto_tid_ops_map.setdefault(op.tid, []).append(op) + if op.is_cpu_op(): kineto_cpu_ops.append(op) kineto_tid_cpu_ops_map.setdefault(op.tid, []).append(op) @@ -144,10 +152,14 @@ def construct_dev_data_structures(self, kineto_ops: List[KinetoOperator], trace_ thread_start_end[0] = min(thread_start_end[0], op.timestamp) thread_start_end[1] = max(thread_start_end[1], op.timestamp + op.inclusive_dur) + if op.external_id is not None: + kineto_external_id_to_kineto_op_map[op.external_id] = op + kineto_rf_id_to_kineto_op_map = {op.rf_id: op for op in kineto_cpu_ops if op.rf_id is not None} return { "kineto_cpu_ops": kineto_cpu_ops, + "kineto_tid_ops_map": kineto_tid_ops_map, "kineto_tid_cpu_ops_map": kineto_tid_cpu_ops_map, "kineto_correlation_cuda_runtime_map": kineto_correlation_cuda_runtime_map, "kineto_gpu_ops": kineto_gpu_ops, @@ -159,6 +171,7 @@ def construct_dev_data_structures(self, kineto_ops: List[KinetoOperator], trace_ "kineto_rf_id_to_kineto_op_map": kineto_rf_id_to_kineto_op_map, "sorted_kineto_cpu_ops": [], "sorted_kineto_cpu_op_ts": [], + "kineto_external_id_to_kineto_op_map": kineto_external_id_to_kineto_op_map, } def calculate_exclusive_dur(self, kineto_tid_cpu_ops_map: Dict[int, List[KinetoOperator]]) -> None: diff --git a/src/trace_link/kineto_operator.py b/src/trace_link/kineto_operator.py index 95bbcede..f074f063 100644 --- a/src/trace_link/kineto_operator.py +++ b/src/trace_link/kineto_operator.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from et_replay.execution_trace import Node as PyTorchOperator @@ -22,6 +22,7 @@ class KinetoOperator: host_op (Optional[PyTorchOperator]): Corresponding PyTorch operator object. parent_host_op_id (Optional[int]): ID of the parent PyTorch operator. inter_thread_dep (Optional[int]): Identifier for inter-thread dependencies. + sync_dep (List[KinetoOperator]): List of KinetoOperator objects that have dependencies on this operator. stream (Optional[int]): CUDA stream identifier associated with the operator. rf_id (Optional[int]): Record function identifier. correlation (int): Identifier used to correlate CUDA runtime and GPU operations. 
@@ -49,6 +50,7 @@ def __init__(self, kineto_op: Dict[str, Any]) -> None: self.host_op: Optional[PyTorchOperator] = None self.parent_host_op_id: Optional[int] = None self.inter_thread_dep: Optional[int] = None + self.sync_dep: List[KinetoOperator] = [] self.stream: Optional[int] = kineto_op.get("args", {}).get("stream", None) self.rf_id: Optional[int] = kineto_op.get("args", {}).get("Record function id", None) self.correlation: int = kineto_op.get("args", {}).get("correlation", -1) @@ -61,13 +63,14 @@ def __repr__(self) -> str: Returns str: A string representation of the KinetoOperator. """ + sync_dep_ids = [op.id for op in self.sync_dep] return ( f"KinetoOperator(id={self.id}, category={self.category}, name={self.name}, " f"phase={self.phase}, inclusive_dur={self.inclusive_dur}, " f"exclusive_dur={self.exclusive_dur}, timestamp={self.timestamp}, " f"external_id={self.external_id}, ev_idx={self.ev_idx}, tid={self.tid}, " f"parent_host_op_id={self.parent_host_op_id}, inter_thread_dep={self.inter_thread_dep}, " - f"stream={self.stream}, rf_id={self.rf_id}, correlation={self.correlation})" + f"sync_dep={sync_dep_ids}, stream={self.stream}, rf_id={self.rf_id}, correlation={self.correlation})" ) def is_cpu_op(self) -> bool: diff --git a/src/trace_link/trace_link.py b/src/trace_link/trace_link.py index f05b800b..12074df5 100644 --- a/src/trace_link/trace_link.py +++ b/src/trace_link/trace_link.py @@ -18,6 +18,7 @@ def main() -> None: "Merging-PyTorch-and-Kineto-Traces" ) ) + parser.add_argument("--rank", type=int, required=True, help="Rank for the input traces") parser.add_argument( "--chakra-host-trace", type=str, @@ -43,10 +44,11 @@ def main() -> None: logging.basicConfig(level=args.log_level.upper()) linker = TraceLinker() - linker.link(args.chakra_host_trace, args.chakra_device_trace, args.output_file) + linker.link(args.rank, args.chakra_host_trace, args.chakra_device_trace, args.output_file) logging.info(f"Linking process successful. Output file is available at {args.output_file}.") logging.info("Please run the chakra_converter for further postprocessing.") + if __name__ == "__main__": main() diff --git a/src/trace_link/trace_linker.py b/src/trace_link/trace_linker.py index aeeb0358..ade9ac12 100644 --- a/src/trace_link/trace_linker.py +++ b/src/trace_link/trace_linker.py @@ -2,6 +2,7 @@ import copy import json import logging +import os from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Dict, List, Optional, Tuple @@ -10,6 +11,8 @@ EXECUTION_TRACE_THREAD_ANNOTATION, ) from et_replay.execution_trace import Node as PyTorchOperator +from hta.analyzers.critical_path_analysis import CPEdgeType +from hta.trace_analysis import TraceAnalysis from .chakra_device_trace_loader import ChakraDeviceTraceLoader from .chakra_host_trace_loader import ChakraHostTraceLoader @@ -33,11 +36,12 @@ def __init__(self) -> None: self.chakra_device_trace_loader = ChakraDeviceTraceLoader() self.id_assigner = UniqueIdAssigner() - def link(self, chakra_host_trace: str, chakra_device_trace: str, output_file: str) -> None: + def link(self, rank: int, chakra_host_trace: str, chakra_device_trace: str, output_file: str) -> None: """ Links Chakra host execution traces (ET) and Chakra device ET to generate Chakra host + device ET. Args: + rank (int): Rank for the input traces. chakra_host_trace (str): Path to the Chakra host execution trace file. chakra_device_trace (str): Path to the Kineto trace file. output_file (str): Path for the output nyTorch execution trace plus file. 
@@ -46,6 +50,7 @@ def link(self, chakra_host_trace: str, chakra_device_trace: str, output_file: st ( kineto_cpu_ops, + kineto_tid_ops_map, kineto_tid_cpu_ops_map, kineto_correlation_cuda_runtime_map, kineto_gpu_ops, @@ -57,10 +62,20 @@ def link(self, chakra_host_trace: str, chakra_device_trace: str, output_file: st kineto_rf_id_to_device_op_map, sorted_kineto_cpu_ops, sorted_kineto_cpu_op_ts, + kineto_external_id_to_kineto_op_map, ) = self.chakra_device_trace_loader.load(chakra_device_trace) kineto_tid_cpu_ops_map = self.enforce_inter_thread_order(kineto_tid_cpu_ops_map) + sync_deps = self.load_sync_dependencies(rank, chakra_device_trace) + self.enforce_sync_dep( + kineto_external_id_to_kineto_op_map, + sorted_kineto_cpu_ops, + sorted_kineto_cpu_op_ts, + kineto_tid_ops_map, + sync_deps, + ) + chakra_execution_trace_plus_data = self.link_traces( chakra_host_trace, host_ops, @@ -73,10 +88,66 @@ def link(self, chakra_host_trace: str, chakra_device_trace: str, output_file: st kineto_thread_debug, kineto_process_start_time, kineto_process_end_time, + kineto_external_id_to_kineto_op_map, ) self.dump_chakra_execution_trace_plus(chakra_execution_trace_plus_data, output_file) + def load_sync_dependencies( + self, rank: int, kineto_file: str, annotation: str = "ProfilerStep", instance_id: int = 0 + ) -> Dict[int, List[int]]: + """ + Load synchronization dependencies using Holistic Trace Analysis (HTA). + + Holistic Trace Analysis (HTA) provides various features for trace analysis, one of which is critical path + analysis. This feature identifies dependencies between GPU and CPU operators that are in the critical path. + This method leverages HTA's critical path analysis to determine synchronization points and dependencies, + returning them as a dictionary. + + Args: + rank (int): Rank for the input Kineto trace. + kineto_file (str): Path to the Kineto trace file. + annotation (str): Annotation to use for the analysis. Defaults to "ProfilerStep". + instance_id (int): Instance ID for the analysis. Defaults to 0. + + Returns: + Dict[int, List[int]]: A dictionary mapping end event's external ID to a list of start event's external IDs + that have synchronization dependencies. 
+ """ + sync_dependencies = {} + trace_analysis = TraceAnalysis(trace_dir=os.path.dirname(kineto_file)) + cp_graph, success = trace_analysis.critical_path_analysis( + rank=rank, annotation=annotation, instance_id=instance_id + ) + if not success: + logging.error("Failed to load Critical Path Graph") + return sync_dependencies + + raw_events = trace_analysis.t.get_raw_trace_for_one_rank(rank=rank)["traceEvents"] + for edge in cp_graph.critical_path_edges_set: + if edge.type in [CPEdgeType.SYNC_DEPENDENCY]: + start_event_id, end_event_id = cp_graph.get_events_for_edge(edge) + start_event, end_event = raw_events[start_event_id], raw_events[end_event_id] + if "External id" in end_event["args"] and "External id" in start_event["args"]: + start_event_external_id = start_event["args"]["External id"] + end_event_external_id = end_event["args"]["External id"] + start_event_name = start_event["name"] + end_event_name = end_event["name"] + if start_event_external_id != end_event_external_id: + logging.info( + f"Sync dep: start_event_id {start_event_id}, end_event_id {end_event_id}, " + f"start_ext_id {start_event_external_id}, end_ext_id {end_event_external_id}, " + f"start_event_name '{start_event_name}', end_event_name '{end_event_name}'" + ) + sync_dependencies.setdefault(end_event_external_id, []).append(start_event_external_id) + else: + logging.warning( + f"Synchronization dependency from event {start_event_id} to event {end_event_id} will " + "not be considered due to missing external IDs." + ) + + return sync_dependencies + def enforce_inter_thread_order( self, kineto_tid_cpu_ops_map: Dict[int, List[KinetoOperator]], threshold: int = 1000 ) -> Dict[int, List[KinetoOperator]]: @@ -184,6 +255,123 @@ def find_last_cpu_node_before_timestamp( logging.debug(f"Last CPU node before timestamp {timestamp} found: {last_cpu_node}") return last_cpu_node_rf_id + def enforce_sync_dep( + self, + kineto_external_id_to_kineto_op_map: Dict[int, KinetoOperator], + sorted_kineto_cpu_ops: List[KinetoOperator], + sorted_kineto_cpu_op_ts: List[int], + kineto_tid_ops_map: Dict[int, List[KinetoOperator]], + sync_deps: Dict[int, List[int]], + ): + """ + Enforces synchronization order by storing Kineto ops that have synchronization dependency. + + Args: + kineto_external_id_to_kineto_op_map (Dict[int, KinetoOperator]): Mapping between external ID and Kineto + operators. + sorted_kineto_cpu_ops (List[KinetoOperator]): Sorted list of Kineto CPU operators. + sorted_kineto_cpu_op_ts (List[int]): Sorted list of timestamps for the Kineto CPU operators. + kineto_tid_ops_map (Dict[int, List[KinetoOperator]]): Kineto operators grouped by thread ID. + sync_deps (Dict[int, List[int]]): A dictionary mapping end event's external ID to a list of start event's + external IDs that have synchronization dependencies. 
+ """ + logging.info("Enforcing sync order in Kineto traces.") + + with ThreadPoolExecutor() as executor: + futures = { + executor.submit( + self.process_thread_sync_dep, + kineto_external_id_to_kineto_op_map, + sorted_kineto_cpu_ops, + sorted_kineto_cpu_op_ts, + tid, + ops, + sync_deps, + ): tid + for tid, ops in kineto_tid_ops_map.items() + } + + for future in as_completed(futures): + tid = futures[future] + future.result() + logging.debug(f"Thread {tid} sync dependencies processed.") + + def process_thread_sync_dep( + self, + kineto_external_id_to_kineto_op_map: Dict[int, KinetoOperator], + sorted_kineto_cpu_ops: List[KinetoOperator], + sorted_kineto_cpu_op_ts: List[int], + tid: int, + ops: List[KinetoOperator], + sync_deps: Dict[int, List[int]], + ) -> None: + """ + Process synchronization dependencies for a specific thread. + + This method identifies synchronization dependencies for each operator within the current thread + and updates the `sync_dep` attribute of each operator accordingly. + + Args: + kineto_external_id_to_kineto_op_map (Dict[int, KinetoOperator]): Mapping between external ID and Kineto + operators. + sorted_kineto_cpu_ops (List[KinetoOperator]): Sorted list of Kineto CPU operators. + sorted_kineto_cpu_op_ts (List[int]): Sorted list of timestamps for the Kineto CPU operators. + tid (int): The current thread ID being processed. + ops (List[KinetoOperator]): Kineto operators. + sync_deps (Dict[int, List[int]]): A dictionary mapping end event's external ID to a list of start event's + external IDs that have synchronization dependencies. + """ + logging.info(f"Thread {tid}: Identifying synchronization dependency.") + for op in ops: + if op.external_id in sync_deps: + sync_start_external_ids = sync_deps[op.external_id] + + for external_id in sync_start_external_ids: + if external_id in kineto_external_id_to_kineto_op_map: + start_sync_op = kineto_external_id_to_kineto_op_map[external_id] + + # Find the closest Kineto operator with a start time later than the current op's timestamp + closest_start_kineto_op = self.find_closest_start_kineto_op( + op, sorted_kineto_cpu_ops, sorted_kineto_cpu_op_ts + ) + + # Add the external ID of the start_sync_op to closest_start_kineto_op.sync_dep if not present + if (closest_start_kineto_op is not None) and ( + start_sync_op not in closest_start_kineto_op.sync_dep + ): + start_sync_op.sync_dep.append(closest_start_kineto_op) + logging.info( + f"Sync dependency: end op {closest_start_kineto_op.name} " + f"(external_id: {closest_start_kineto_op.external_id}, " + f"timestamp: {closest_start_kineto_op.timestamp})" + f" -> start op {start_sync_op.name} (external_id: {start_sync_op.external_id})" + ) + + def find_closest_start_kineto_op( + self, op: KinetoOperator, sorted_kineto_cpu_ops: List[KinetoOperator], sorted_kineto_cpu_op_ts: List[int] + ) -> Optional[KinetoOperator]: + """ + Find the closest start Kineto operator that occurs after the given operator's timestamp. + + Args: + op (KinetoOperator): The current Kineto operator. + sorted_kineto_cpu_ops (List[KinetoOperator]): Sorted list of Kineto CPU operators. + sorted_kineto_cpu_op_ts (List[int]): Sorted list of timestamps for the Kineto CPU operators. + + Returns: + Optional[KinetoOperator]: The closest start Kineto operator found, or None if not found. 
+ """ + index = bisect.bisect_right(sorted_kineto_cpu_op_ts, op.timestamp) + closest_start_kineto_op = None + + for i in range(index, len(sorted_kineto_cpu_op_ts)): + potential_sync_op = sorted_kineto_cpu_ops[i] + if potential_sync_op.timestamp > op.timestamp: + closest_start_kineto_op = potential_sync_op + break + + return closest_start_kineto_op + def link_traces( self, chakra_host_trace: str, @@ -197,6 +385,7 @@ def link_traces( kineto_thread_debug: Dict[int, Tuple[int, int]], kineto_process_start_time: int, kineto_process_end_time: int, + kineto_external_id_to_kineto_op_map: Dict[int, KinetoOperator], ) -> Dict: """ Link Chakra Host ET and Chakra Device ET to produce an enhanced Chakra ET (ET +). @@ -215,6 +404,8 @@ def link_traces( of start and end times. kineto_process_start_time (int): Start time of the process, based on the earliest operator timestamp. kineto_process_end_time (int): End time of the process, based on the latest operator timestamp. + kineto_external_id_to_kineto_op_map (Dict[int, KinetoOperator]): Mapping between external ID and Kineto + operators. Returns: Dict: The enhanced Chakra Host Execution Trace (ET+). @@ -246,6 +437,7 @@ def link_traces( kineto_correlation_cuda_runtime_map, kineto_rf_id_to_device_op_map, kineto_gpu_ops, + kineto_external_id_to_kineto_op_map, ) chakra_execution_trace_plus_data = self.construct_et_plus_data( chakra_host_trace, @@ -337,6 +529,7 @@ def map_host_to_device_ops( kineto_correlation_cuda_runtime_map: Dict[int, KinetoOperator], kineto_rf_id_to_device_op_map: Dict[int, KinetoOperator], kineto_gpu_ops: List[KinetoOperator], + kineto_external_id_to_kineto_op_map, ) -> Tuple[ Dict[int, List[KinetoOperator]], Dict[int, int], @@ -376,6 +569,7 @@ def map_host_to_device_ops( kineto_op, cpu_ev_idx_to_gpu_ops_map, kineto_rf_id_to_device_op_map, + kineto_external_id_to_kineto_op_map, ) logging.debug("Completed mapping of Chakra host operators to Kineto operators.") @@ -560,6 +754,7 @@ def link_ops( kineto_op: KinetoOperator, cpu_ev_idx_to_gpu_ops_map: Dict[int, List[KinetoOperator]], kineto_rf_id_to_device_op_map: Dict[int, KinetoOperator], + kineto_external_id_to_kineto_op_map: Dict[int, KinetoOperator], ) -> Tuple[List[KinetoOperator], int, int, int, Optional[int]]: """ Link a Chakra host operator to its corresponding Kineto operator and any associated GPU operators. @@ -569,6 +764,8 @@ def link_ops( kineto_op (KinetoOperator): Corresponding Kineto operator. cpu_ev_idx_to_gpu_ops_map (Dict[int, List[KinetoOperator]]): GPU ops mapping. kineto_rf_id_to_device_op_map (Dict[int, KinetoOperator]): Kineto operator mapping. + kineto_external_id_to_kineto_op_map (Dict[int, KinetoOperator]): Mapping from external id to + KinetoOperators. Returns: Tuple containing: @@ -577,6 +774,7 @@ def link_ops( - int: The exclusive duration of the linked Kineto operator. - int: The timestamp of the linked Kineto operator. - Optional[int]: The inter-thread dependency ID if present. + - List[int]: List of synchronization dependency IDs. 
""" kineto_op.host_op = host_op linked_gpu_ops = cpu_ev_idx_to_gpu_ops_map.get(kineto_op.ev_idx, []) @@ -665,9 +863,22 @@ def construct_et_plus_data( ) pytorch_et_data["nodes"] += gpu_ops + # Add sync dependencies + sync_dep_mapping = {} + for gpu_op in gpu_ops: + if "sync_dep_to" in gpu_op: + for sync_dep_to in gpu_op["sync_dep_to"]: + if sync_dep_to not in sync_dep_mapping: + sync_dep_mapping[sync_dep_to] = [] + sync_dep_mapping[sync_dep_to].append(gpu_op["id"]) + del gpu_op["sync_dep_to"] + # Update parent-child relationships with new IDs sorted_nodes = sorted(pytorch_et_data["nodes"], key=lambda x: x["id"]) for op in sorted_nodes: + for key in sync_dep_mapping: + if self.id_assigner.lookup_new_id(key) == op["id"]: + op["sync_dep"] = sync_dep_mapping[key] if "ctrl_deps" in op: op["ctrl_deps"] = self.id_assigner.assign_or_retrieve_id(op["ctrl_deps"]) @@ -762,9 +973,15 @@ def process_dependent_gpu_ops( ), } ) - updated_gpu_ops.append(new_gpu_op) + for sync_dep in gpu_op.sync_dep: + if sync_dep.host_op: + if "sync_dep_to" not in new_gpu_op: + new_gpu_op["sync_dep_to"] = [] + if self.id_assigner.lookup_new_id(sync_dep.host_op.id) not in new_gpu_op["sync_dep_to"]: + new_gpu_op["sync_dep_to"].append(self.id_assigner.lookup_new_id(sync_dep.host_op.id)) + return updated_gpu_ops def dump_chakra_execution_trace_plus(self, chakra_execution_trace_plus_data: Dict, output_file: str) -> None: diff --git a/tests/trace_link/test_kineto_operator.py b/tests/trace_link/test_kineto_operator.py index f72d9dc2..772f8852 100644 --- a/tests/trace_link/test_kineto_operator.py +++ b/tests/trace_link/test_kineto_operator.py @@ -43,7 +43,7 @@ def test_repr_method(sample_operator_data): expected_repr = ( "KinetoOperator(id=None, category=Kernel, name=cudaLaunchKernel, phase=X, " "inclusive_dur=100, exclusive_dur=100, timestamp=1590000000, external_id=123, ev_idx=456, " - "tid=1234, parent_host_op_id=None, inter_thread_dep=None, stream=7, rf_id=12, " + "tid=1234, parent_host_op_id=None, inter_thread_dep=None, sync_dep=[], stream=7, rf_id=12, " "correlation=99)" ) assert repr(operator) == expected_repr diff --git a/tests/trace_link/test_trace_linker.py b/tests/trace_link/test_trace_linker.py index 4d320acd..a0441ae4 100644 --- a/tests/trace_link/test_trace_linker.py +++ b/tests/trace_link/test_trace_linker.py @@ -9,6 +9,7 @@ EXECUTION_TRACE_THREAD_ANNOTATION, ) from et_replay.execution_trace import Node as PyTorchOperator +from hta.analyzers.critical_path_analysis import CPEdgeType, CPGraph @pytest.fixture @@ -38,6 +39,7 @@ def test_link_traces(mock_map_ops, mock_add_annotations, mock_construct_et_plus, kineto_thread_info = {1: (100, 200)} kineto_process_start_time = 50 kineto_process_end_time = 300 + kineto_external_id_to_kineto_op_map = {1: MagicMock(spec=KinetoOperator)} trace_linker.link_traces( "pytorch_et_file", @@ -51,6 +53,7 @@ def test_link_traces(mock_map_ops, mock_add_annotations, mock_construct_et_plus, kineto_thread_info, kineto_process_start_time, kineto_process_end_time, + kineto_external_id_to_kineto_op_map, ) mock_add_annotations.assert_called_once() @@ -90,6 +93,189 @@ def test_enforce_inter_thread_order_exception(mock_process_thread, mock_future_r trace_linker.enforce_inter_thread_order(kineto_tid_cpu_ops_map) +@pytest.mark.parametrize( + "trace_events, critical_path_edges, expected_result", + [ + ( + {0: {"args": {"External id": 1}, "name": "event_0"}, 1: {"args": {"External id": 2}, "name": "event_1"}}, + [(0, 1)], + {2: [1]}, + ), + ( + { + 0: {"args": {"External id": 1}, "name": "event_0"}, + 1: 
{"args": {"External id": 2}, "name": "event_1"}, + 2: {"args": {"External id": 3}, "name": "event_2"}, + }, + [(0, 2), (1, 2)], + {3: [1, 2]}, + ), + ( + { + 0: {"args": {"External id": 1}, "name": "event_0"}, + 1: {"args": {"External id": 2}, "name": "event_1"}, + 2: {"args": {}, "name": "event_2"}, + }, + [(0, 1), (1, 2)], + {2: [1]}, + ), + ], +) +@patch("hta.trace_analysis.TraceAnalysis.critical_path_analysis") +@patch("hta.trace_analysis.Trace") +def test_load_sync_dependencies_success( + mock_trace, mock_critical_path_analysis, trace_linker, trace_events, critical_path_edges, expected_result +): + # Mock the Trace instance and its methods + mock_trace_instance = mock_trace.return_value + mock_trace_instance.is_parsed = True + mock_trace_instance.get_raw_trace_for_one_rank.return_value = {"traceEvents": trace_events} + mock_cp_graph = MagicMock(spec=CPGraph) + mock_cp_graph.critical_path_edges_set = [MagicMock(type=CPEdgeType.SYNC_DEPENDENCY) for _ in critical_path_edges] + mock_cp_graph.get_events_for_edge.side_effect = critical_path_edges + mock_critical_path_analysis.return_value = (mock_cp_graph, True) + + # Call the method + result = trace_linker.load_sync_dependencies(0, "kineto_file.json", "ProfilerStep", 0) + + # Assert the expected result + assert result == expected_result + + +@patch("hta.trace_analysis.TraceAnalysis.critical_path_analysis") +@patch("hta.trace_analysis.Trace") +def test_load_sync_dependencies_failure(mock_trace, mock_critical_path_analysis, trace_linker): + # Mock the Trace instance and its methods to return failure + mock_trace_instance = mock_trace.return_value + mock_trace_instance.is_parsed = True + mock_cp_graph = MagicMock(spec=CPGraph) + mock_cp_graph.critical_path_edges_set = [] + mock_critical_path_analysis.return_value = (mock_cp_graph, False) + + # Call the method + result = trace_linker.load_sync_dependencies(0, "kineto_file.json", "ProfilerStep", 0) + + # Assert the expected result + assert result == {} + + +@patch("hta.trace_analysis.TraceAnalysis.critical_path_analysis") +@patch("hta.trace_analysis.Trace") +def test_load_sync_dependencies_missing_external_id(mock_trace, mock_critical_path_analysis, trace_linker): + # Mock the Trace instance and its methods + mock_trace_instance = mock_trace.return_value + mock_trace_instance.is_parsed = True + mock_trace_instance.get_raw_trace_for_one_rank.return_value = { + "traceEvents": {0: {"args": {"External id": 1}}, 1: {"args": {}}} + } + mock_cp_graph = MagicMock(spec=CPGraph) + mock_cp_graph.critical_path_edges_set = [MagicMock(type=CPEdgeType.SYNC_DEPENDENCY)] + mock_cp_graph.get_events_for_edge.return_value = (0, 1) + mock_critical_path_analysis.return_value = (mock_cp_graph, True) + + # Call the method + result = trace_linker.load_sync_dependencies(0, "kineto_file.json", "ProfilerStep", 0) + + # Assert the expected result + assert result == {} + + +@pytest.mark.parametrize( + "sync_deps, current_tid, ops_by_tid, expected_sync_deps", + [ + ( + {1: [10, 20], 2: [30, 40]}, + 1, + { + 1: [KinetoOperator({"external_id": 1}), KinetoOperator({"external_id": 2})], + 2: [KinetoOperator({"external_id": 3})], + }, + {1: [10, 20], 2: [30, 40]}, + ), + ( + {1: [10], 2: [20, 30], 3: [40]}, + 2, + { + 1: [KinetoOperator({"external_id": 1})], + 2: [KinetoOperator({"external_id": 2}), KinetoOperator({"external_id": 3})], + }, + {2: [20, 30], 3: [40]}, + ), + ( + {}, + 1, + { + 1: [KinetoOperator({"external_id": 1}), KinetoOperator({"external_id": 2})], + }, + {}, + ), + ], +) +def 
test_process_thread_sync_dep(sync_deps, current_tid, ops_by_tid, expected_sync_deps, trace_linker): + trace_linker.process_thread_sync_dep({}, ops_by_tid, [], current_tid, ops_by_tid[current_tid], sync_deps) + + for op in ops_by_tid.get(current_tid, []): + assert op.sync_dep == expected_sync_deps.get(op.external_id, []) + + +@pytest.mark.parametrize( + "op, sorted_kineto_cpu_ops, sorted_kineto_cpu_op_ts, expected_result", + [ + # Case 1: One operator after the given timestamp + ( + MagicMock(spec=KinetoOperator, timestamp=100), + [MagicMock(spec=KinetoOperator, timestamp=50), MagicMock(spec=KinetoOperator, timestamp=150)], + [50, 150], + 1, # index of the expected result in sorted_kineto_cpu_ops + ), + # Case 2: Multiple operators after the given timestamp + ( + MagicMock(spec=KinetoOperator, timestamp=100), + [ + MagicMock(spec=KinetoOperator, timestamp=50), + MagicMock(spec=KinetoOperator, timestamp=150), + MagicMock(spec=KinetoOperator, timestamp=200), + ], + [50, 150, 200], + 1, # index of the expected result in sorted_kineto_cpu_ops + ), + # Case 3: No operators after the given timestamp + ( + MagicMock(spec=KinetoOperator, timestamp=100), + [MagicMock(spec=KinetoOperator, timestamp=50), MagicMock(spec=KinetoOperator, timestamp=75)], + [50, 75], + None, # No result expected + ), + # Case 4: Operator with exact timestamp + ( + MagicMock(spec=KinetoOperator, timestamp=100), + [ + MagicMock(spec=KinetoOperator, timestamp=50), + MagicMock(spec=KinetoOperator, timestamp=100), + MagicMock(spec=KinetoOperator, timestamp=150), + ], + [50, 100, 150], + 2, # index of the expected result in sorted_kineto_cpu_ops + ), + # Case 5: Empty list of operators + ( + MagicMock(spec=KinetoOperator, timestamp=100), + [], + [], + None, # No result expected + ), + ], +) +def test_find_closest_start_kineto_op( + op, sorted_kineto_cpu_ops, sorted_kineto_cpu_op_ts, expected_result, trace_linker +): + result = trace_linker.find_closest_start_kineto_op(op, sorted_kineto_cpu_ops, sorted_kineto_cpu_op_ts) + if expected_result is not None: + assert result == sorted_kineto_cpu_ops[expected_result] + else: + assert result is None + + @pytest.mark.parametrize( "ops_by_tid, exclude_tid, timestamp, expected_result", [ @@ -307,12 +493,17 @@ def test_link_ops( cpu_ev_idx_to_gpu_ops_map = {kineto_op.ev_idx: expected_linked_gpu_ops} kineto_rf_id_to_kineto_op_map = {1: MagicMock(spec=KinetoOperator, host_op=MagicMock(id=42))} + kineto_external_id_to_kineto_op_map = { + 2: MagicMock(spec=KinetoOperator, host_op=MagicMock(id=3)), + 3: MagicMock(spec=KinetoOperator, host_op=MagicMock(id=4)), + } result = trace_linker.link_ops( host_op, kineto_op, cpu_ev_idx_to_gpu_ops_map, kineto_rf_id_to_kineto_op_map, + kineto_external_id_to_kineto_op_map, ) assert result == ( @@ -335,16 +526,19 @@ def test_link_ops_with_no_gpu_ops(trace_linker): timestamp=123456, host_op=None, inter_thread_dep=None, + sync_dep=[], ) cpu_ev_idx_to_gpu_ops_map = {} kineto_rf_id_to_kineto_op_map = {} + kineto_external_id_to_kineto_op_map = {} result = trace_linker.link_ops( host_op, kineto_op, cpu_ev_idx_to_gpu_ops_map, kineto_rf_id_to_kineto_op_map, + kineto_external_id_to_kineto_op_map, ) assert result == ([], 100, 50, 123456, None) @@ -470,6 +664,7 @@ def test_process_dependent_gpu_ops(trace_linker, orig_op_id, cpu_op, kineto_gpu_ gpu_op.exclusive_dur = gpu_op_data["exclusive_dur"] gpu_op.stream = gpu_op_data["stream"] gpu_op.pg_name = gpu_op_data.get("pg_name", None) + gpu_op.sync_dep = [] kineto_gpu_op_objects.append(gpu_op) host_op_id_to_kineto_ops_map = 
{orig_op_id: kineto_gpu_op_objects}

From afe6713faa72b657812e3e76dd13c7068aafc426 Mon Sep 17 00:00:00 2001
From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com>
Date: Mon, 10 Jun 2024 20:03:53 -0400
Subject: [PATCH 2/5] Add synchronization dependency from GPU to CPU operators
 in PyTorch converter

Co-authored-by: Joongun Park <8554137+JoongunPark@users.noreply.github.com>
---
 src/converter/pytorch_converter.py | 9 +++++++++
 src/converter/pytorch_node.py      | 1 +
 2 files changed, 10 insertions(+)

diff --git a/src/converter/pytorch_converter.py b/src/converter/pytorch_converter.py
index 9474c556..ea383a51 100644
--- a/src/converter/pytorch_converter.py
+++ b/src/converter/pytorch_converter.py
@@ -457,6 +457,15 @@ def convert_ctrl_dep_to_data_dep(
                 last_visited_non_gpu = current_node
             last_visited_any = current_node
 
+            if json_node.sync_dep:
+                for sync_dep in json_node.sync_dep:
+                    if sync_dep not in current_node.data_deps:
+                        current_node.data_deps.append(sync_dep)
+                        logging.info(
+                            f"Node ID {current_node.id} now has a synchronization dependency on Node ID {sync_dep}"
+                        )
+
+            # Add children to the stack
             children_chakra_ids = [child.id for child in json_node.children]
             for child_chakra_id in sorted(children_chakra_ids, reverse=True):
                 child_chakra_node = protobuf_node_map.get(child_chakra_id)
diff --git a/src/converter/pytorch_node.py b/src/converter/pytorch_node.py
index 50feb4ac..86b59acc 100644
--- a/src/converter/pytorch_node.py
+++ b/src/converter/pytorch_node.py
@@ -110,6 +110,7 @@ def _parse_data_1_0_3_chakra_0_0_4(self, node_data: Dict[str, Any]) -> None:
         self.exclusive_dur = node_data.get("exclusive_dur", 0)
         self.ts = node_data.get("ts")
         self.inter_thread_dep = node_data.get("inter_thread_dep")
+        self.sync_dep = node_data.get("sync_dep")
         self.cat = node_data.get("cat")
         self.stream = node_data.get("stream", 0)
         # In Colletive comms nodes, pg_name is in node_data if exists.

From 8f3746288658bfc8fcb569997317d4adfc746fa7 Mon Sep 17 00:00:00 2001
From: JoongunPark <8554137+JoongunPark@users.noreply.github.com>
Date: Fri, 10 May 2024 18:00:36 -0400
Subject: [PATCH 3/5] Update user guide to include HTA installation and new
 argument for chakra_trace_link

Co-authored-by: Joongun Park <8554137+JoongunPark@users.noreply.github.com>
---
 USER_GUIDE.md | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/USER_GUIDE.md b/USER_GUIDE.md
index 916c0714..164aaeb6 100644
--- a/USER_GUIDE.md
+++ b/USER_GUIDE.md
@@ -35,7 +35,19 @@ $ git checkout 7b19f586dd8b267333114992833a0d7e0d601630
 $ pip install .
 ```
 
-### Step 4: Uninstalling Chakra
+### Step 4: Install Holistic Trace Analysis
+Installing Holistic Trace Analysis (HTA) is required for trace linking.
+
+```bash
+$ git clone https://github.com/facebookresearch/HolisticTraceAnalysis.git
+$ cd HolisticTraceAnalysis
+$ git checkout d731cc2e2249976c97129d409a83bd53d93051f6
+$ git submodule update --init
+$ pip install -r requirements.txt
+$ pip install -e .
+```
+
+### Step 5: Uninstalling Chakra
 To uninstall Chakra, use the following command within the virtual environment.
```bash @@ -49,6 +61,7 @@ Merge Chakra host execution trace and Chakra device execution trace to encode GP $ chakra_trace_link \ --chakra-host-trace /path/to/chakra_host_trace \ --chakra-device-trace /path/to/chakra_device_trace \ + --rank [RANK] \ --output-file /path/to/chakra_host_device_trace.json ``` From bece42fc66cea67323dfb5f6b5c65cd48d3c6fb0 Mon Sep 17 00:00:00 2001 From: JoongunPark <8554137+JoongunPark@users.noreply.github.com> Date: Fri, 10 May 2024 18:27:39 -0400 Subject: [PATCH 4/5] Add HTA installation step in end_to_end_tests workflow Co-authored-by: Joongun Park <8554137+JoongunPark@users.noreply.github.com> --- .github/workflows/end_to_end_tests.yml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.github/workflows/end_to_end_tests.yml b/.github/workflows/end_to_end_tests.yml index 4b8edb2b..fcddc117 100644 --- a/.github/workflows/end_to_end_tests.yml +++ b/.github/workflows/end_to_end_tests.yml @@ -26,6 +26,15 @@ jobs: git checkout 7b19f586dd8b267333114992833a0d7e0d601630 pip install . + - name: Install HTA + run: | + git clone https://github.com/facebookresearch/HolisticTraceAnalysis.git + cd HolisticTraceAnalysis + git checkout d731cc2e2249976c97129d409a83bd53d93051f6 + git submodule update --init + pip install -r requirements.txt + pip install -e . + - name: Test chakra_trace_link Without Arguments run: | chakra_trace_link || [ $? -eq 2 ] From 6bdf24e02db4d6924faae1ec0a06f5804c24c676 Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Wed, 3 Jul 2024 15:10:41 -0400 Subject: [PATCH 5/5] Resolve kineto_file path bug with absolute path conversion --- src/trace_link/trace_linker.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/trace_link/trace_linker.py b/src/trace_link/trace_linker.py index ade9ac12..71696ca8 100644 --- a/src/trace_link/trace_linker.py +++ b/src/trace_link/trace_linker.py @@ -115,7 +115,9 @@ def load_sync_dependencies( that have synchronization dependencies. """ sync_dependencies = {} - trace_analysis = TraceAnalysis(trace_dir=os.path.dirname(kineto_file)) + absolute_kineto_file = os.path.abspath(kineto_file) + trace_dir = os.path.dirname(absolute_kineto_file) + trace_analysis = TraceAnalysis(trace_dir=trace_dir) cp_graph, success = trace_analysis.critical_path_analysis( rank=rank, annotation=annotation, instance_id=instance_id )
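
For readers evaluating the new `load_sync_dependencies` path, below is a minimal standalone sketch of the HTA calls it builds on; the trace directory, rank, and annotation are illustrative assumptions, not values taken from the patches.

```python
# Minimal sketch of extracting synchronization dependencies with HTA, mirroring
# load_sync_dependencies. The path and rank below are placeholders, not values from the patches.
import os

from hta.analyzers.critical_path_analysis import CPEdgeType
from hta.trace_analysis import TraceAnalysis

trace_dir = os.path.abspath("path/to/kineto_traces")  # directory holding per-rank Kineto JSON files
rank = 0

analyzer = TraceAnalysis(trace_dir=trace_dir)
cp_graph, success = analyzer.critical_path_analysis(rank=rank, annotation="ProfilerStep", instance_id=0)

sync_deps = {}
if success:
    raw_events = analyzer.t.get_raw_trace_for_one_rank(rank=rank)["traceEvents"]
    for edge in cp_graph.critical_path_edges_set:
        if edge.type == CPEdgeType.SYNC_DEPENDENCY:
            start_id, end_id = cp_graph.get_events_for_edge(edge)
            start_ext = raw_events[start_id]["args"].get("External id")
            end_ext = raw_events[end_id]["args"].get("External id")
            # Keep only edges where both events carry an external ID, as load_sync_dependencies does.
            if start_ext is not None and end_ext is not None and start_ext != end_ext:
                sync_deps.setdefault(end_ext, []).append(start_ext)

print(sync_deps)  # {end event external ID: [start event external IDs], ...}
```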