Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
TaekyungHeo committed Oct 14, 2024
1 parent b33b868 commit 8e3e909
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 4 deletions.
19 changes: 19 additions & 0 deletions conf/common/test_scenario/nccl_test.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,25 @@
# limitations under the License.

name = "nccl-test"

[prologue]
[[prologue.task]]
name = "setup_environment"
command = "/path/to/setup_script.sh"

[[prologue.task]]
name = "load_modules"
command = "module load nccl"

[epilogue]
[[epilogue.task]]
name = "cleanup_environment"
command = "/path/to/cleanup_script.sh"

[[epilogue.task]]
name = "unload_modules"
command = "module unload nccl"

[[Tests]]
id = "Tests.1"
test_name = "nccl_test_all_reduce"
Expand Down
2 changes: 2 additions & 0 deletions src/cloudai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ._core.job_status_result import JobStatusResult
from ._core.job_status_retrieval_strategy import JobStatusRetrievalStrategy
from ._core.json_gen_strategy import JsonGenStrategy
from ._core.plugin import Plugin
from ._core.registry import Registry
from ._core.report_generation_strategy import ReportGenerationStrategy
from ._core.runner import Runner
Expand Down Expand Up @@ -206,6 +207,7 @@
"JobStatusResult",
"JsonGenStrategy",
"Parser",
"Plugin",
"ReportGenerationStrategy",
"ReportGenerator",
"Runner",
Expand Down
43 changes: 43 additions & 0 deletions src/cloudai/_core/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class Plugin:
    """
    A named shell command attached to a test scenario.

    Attributes:
        name (str): Name of the plugin task.
        command (str): Shell command associated with the plugin task.
    """

    def __init__(self, name: str, command: str) -> None:
        """
        Store the task name together with its command.

        Args:
            name (str): Name of the plugin task.
            command (str): Shell command associated with the plugin task.
        """
        self.name, self.command = name, command

    def __repr__(self) -> str:
        """Return ``Plugin(name=..., command=...)`` with both values inserted unquoted."""
        name, command = self.name, self.command
        return f"Plugin(name={name}, command={command})"

    def run(self) -> None:
        """Print a message announcing the command and its task name; the command itself is not invoked here."""
        message = f"Executing command '{self.command}' for task '{self.name}'"
        print(message)
3 changes: 3 additions & 0 deletions src/cloudai/_core/test_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

if TYPE_CHECKING:
from .test import Test
from .plugin import Plugin


class TestDependency:
Expand Down Expand Up @@ -61,6 +62,8 @@ class TestRun:
weight: float = 0.0
ideal_perf: float = 1.0
dependencies: dict[str, TestDependency] = field(default_factory=dict)
prologue: List[Plugin] = field(default_factory=list)
epilogue: List[Plugin] = field(default_factory=list)

def __hash__(self) -> int:
return hash(self.name + self.test.name + str(self.iterations) + str(self.current_iteration))
Expand Down
37 changes: 34 additions & 3 deletions src/cloudai/_core/test_scenario_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
# limitations under the License.

import logging
from typing import Any, Dict, Literal, Optional
from typing import Any, Dict, List, Literal, Optional

import toml
from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator

from .exceptions import TestScenarioParsingError, format_validation_error
from .plugin import Plugin
from .test import Test
from .test_scenario import TestDependency, TestRun, TestScenario

Expand Down Expand Up @@ -136,8 +137,10 @@ def _parse_data(self, data: Dict[str, Any]) -> TestScenario:
total_weight = sum(tr.weight for tr in ts_model.tests)
normalized_weight = 0 if total_weight == 0 else 100 / total_weight

prologue = self._parse_plugins(data.get("prologue", {}))
epilogue = self._parse_plugins(data.get("epilogue", {}))
testruns_by_id: dict[str, TestRun] = {
tr.id: self._create_section_test_run(tr, normalized_weight) for tr in ts_model.tests
tr.id: self._create_section_test_run(tr, normalized_weight, prologue, epilogue) for tr in ts_model.tests
}

tests_data: dict[str, _TestRunTOML] = {tr.id: tr for tr in ts_model.tests}
Expand All @@ -154,13 +157,21 @@ def _parse_data(self, data: Dict[str, Any]) -> TestScenario:
job_status_check=ts_model.job_status_check,
)

def _create_section_test_run(self, test_info: _TestRunTOML, normalized_weight: float) -> TestRun:
def _create_section_test_run(
self,
test_info: _TestRunTOML,
normalized_weight: float,
prologue: Optional[List[Plugin]],
epilogue: Optional[List[Plugin]],
) -> TestRun:
"""
Create a section-specific Test object by copying from the test mapping.
Args:
test_info (Dict[str, Any]): Information of the test.
normalized_weight (float): Normalized weight for the test.
prologue (Optional[List[Plugin]]): List of plugins to execute before each test.
epilogue (Optional[List[Plugin]]): List of plugins to execute after each test.
Returns:
Test: Copied and updated Test object for the section.
Expand Down Expand Up @@ -193,5 +204,25 @@ def _create_section_test_run(self, test_info: _TestRunTOML, normalized_weight: f
sol=test_info.sol,
weight=test_info.weight * normalized_weight,
ideal_perf=test_info.ideal_perf,
prologue=prologue if prologue is not None else [],
epilogue=epilogue if epilogue is not None else [],
)
return tr

def _parse_plugins(self, plugin_data: Dict[str, Any]) -> List[Plugin]:
    """
    Build Plugin objects from a prologue or epilogue section.

    Args:
        plugin_data (Dict[str, Any]): Section mapping. Its "task" entry, when present, is iterated as a
            sequence of task mappings; when absent, no plugins are produced.

    Returns:
        List[Plugin]: One Plugin per task entry, in the order they appear. A task without a "name" or
            "command" key gets an empty string for that field.
    """
    return [
        Plugin(name=entry.get("name", ""), command=entry.get("command", ""))
        for entry in plugin_data.get("task", [])
    ]
13 changes: 12 additions & 1 deletion src/cloudai/systems/slurm/strategy/slurm_command_gen_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,18 @@ def gen_exec_command(self, tr: TestRun) -> str:
slurm_args = self._parse_slurm_args(
tr.test.test_template.__class__.__name__, env_vars, cmd_args, tr.num_nodes, tr.nodes
)
srun_command = self.generate_srun_command(slurm_args, env_vars, cmd_args, tr.test.extra_cmd_args)

srun_command = ""
if tr.prologue:
for plugin in tr.prologue:
srun_command += plugin.command + "\n"

srun_command += self.generate_srun_command(slurm_args, env_vars, cmd_args, tr.test.extra_cmd_args)

if tr.epilogue:
for plugin in tr.epilogue:
srun_command += plugin.command + "\n"

return self._write_sbatch_script(slurm_args, env_vars, srun_command, tr.output_path)

def generate_test_command(
Expand Down
109 changes: 109 additions & 0 deletions tests/slurm_command_gen_strategy/test_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Generator
from unittest.mock import Mock, patch

import pytest
from cloudai import Plugin, TestRun
from cloudai.systems.slurm.strategy import SlurmCommandGenStrategy


@pytest.fixture
def command_gen_strategy(slurm_system: Mock) -> Generator[SlurmCommandGenStrategy, None, None]:
    """Yield a SlurmCommandGenStrategy whose ``generate_test_command`` is patched to return a fixed string."""
    strategy = SlurmCommandGenStrategy(slurm_system, {"test_arg": "test_value"})

    # The patch stays active for as long as the consuming test runs and is undone afterwards.
    stubbed_command = patch.object(strategy, "generate_test_command", return_value="nccl-test-command")
    with stubbed_command:
        yield strategy


@pytest.fixture
def testrun(tmp_path: Path) -> TestRun:
    """Return a two-node TestRun built around a mocked test, with empty prologue and epilogue lists."""
    # ``name`` has to be assigned after construction: passing name= to Mock() names the mock itself
    # instead of creating a ``name`` attribute.
    definition = Mock(extra_cmd_args="")
    definition.name = "test_job"

    template = Mock()
    template.name = "test_template"

    mocked_test = Mock(test_definition=definition, test_template=template, extra_cmd_args="")

    return TestRun(
        name="test_job",
        test=mocked_test,
        output_path=tmp_path,
        num_nodes=2,
        nodes=["node1", "node2"],
        prologue=[],
        epilogue=[],
    )


@pytest.mark.parametrize(
    "name,command",
    [
        ("setup_environment", "module load nccl"),
        ("cleanup_environment", "module unload nccl"),
    ],
)
def test_plugin_initialization(name: str, command: str) -> None:
    """A Plugin exposes the name and command it was constructed with."""
    created = Plugin(name=name, command=command)
    assert created.name == name
    assert created.command == command


@pytest.mark.parametrize(
    "name,command,expected_repr",
    [
        ("setup_environment", "module load nccl", "Plugin(name=setup_environment, command=module load nccl)"),
        ("cleanup_environment", "module unload nccl", "Plugin(name=cleanup_environment, command=module unload nccl)"),
    ],
)
def test_plugin_repr(name: str, command: str, expected_repr: str) -> None:
    """repr() of a Plugin renders its name and command unquoted inside ``Plugin(...)``."""
    rendered = repr(Plugin(name=name, command=command))
    assert rendered == expected_repr


@patch("builtins.print")
@pytest.mark.parametrize(
    "name,command,expected_output",
    [
        ("setup_environment", "module load nccl", "Executing command 'module load nccl' for task 'setup_environment'"),
        (
            "cleanup_environment",
            "module unload nccl",
            "Executing command 'module unload nccl' for task 'cleanup_environment'",
        ),
    ],
)
def test_plugin_run(mock_print: Mock, name: str, command: str, expected_output: str) -> None:
    """Plugin.run() prints exactly one message naming the command and the task."""
    # ``mock_print`` is injected by the @patch decorator and stands in for builtins.print.
    Plugin(name=name, command=command).run()
    mock_print.assert_called_once_with(expected_output)


def test_generate_srun_command_no_plugins(command_gen_strategy: SlurmCommandGenStrategy, testrun: TestRun) -> None:
    """Calling generate_test_command on the fixture strategy returns the stubbed command string."""
    # NOTE(review): the command_gen_strategy fixture patches generate_test_command, so this asserts the
    # stub's return value rather than any real command generation — confirm that is the intent.
    call_args = ({}, {"test_arg": "test_value"}, testrun.test.extra_cmd_args)
    generated = command_gen_strategy.generate_test_command(*call_args)
    assert generated == "nccl-test-command"

0 comments on commit 8e3e909

Please sign in to comment.