Skip to content

Commit

Permalink
Refactor parser to remove explicit plugin_path argument, use default …
Browse files Browse the repository at this point in the history
…Path
  • Loading branch information
TaekyungHeo committed Oct 29, 2024
1 parent 9511e2c commit ae22fc1
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/cloudai/cli/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def handle_dry_run_and_run(args: argparse.Namespace) -> int:
args (argparse.Namespace): The parsed command-line arguments.
"""
parser = Parser(args.system_config)
system, tests, test_scenario = parser.parse(args.tests_dir, args.test_scenario, Path("conf/common/plugin"))
system, tests, test_scenario = parser.parse(args.tests_dir, args.test_scenario)
assert test_scenario is not None

if args.output_dir:
Expand Down
8 changes: 3 additions & 5 deletions src/cloudai/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def parse(
self,
test_path: Path,
test_scenario_path: Optional[Path] = None,
plugin_path: Optional[Path] = None,
) -> Tuple[System, List[Test], Optional[TestScenario]]:
"""
Parse configurations for system, test templates, and test scenarios.
Expand All @@ -61,7 +60,6 @@ def parse(
test_path (Path): The file path for tests.
test_scenario_path (Optional[Path]): The file path for the main test scenario.
If None, all tests are included.
plugin_path (Optional[Path]): The base file path for plugin-specific tests and scenarios.
Returns:
Tuple[System, List[Test], Optional[TestScenario]]: A tuple containing the system object, a list of filtered
Expand All @@ -80,8 +78,8 @@ def parse(
except TestConfigParsingError:
exit(1) # exit right away to keep error message readable for users

plugin_test_scenario_path = plugin_path
plugin_test_path = plugin_path / "test" if plugin_path else None
plugin_test_scenario_path = Path("conf/common/plugin")
plugin_test_path = Path("conf/common/plugin/test")

plugin_tests = (
self.parse_tests(list(plugin_test_path.glob("*.toml")), system)
Expand All @@ -92,7 +90,7 @@ def parse(
if test_scenario_path:
return self._parse_with_scenario(system, tests, test_scenario_path, plugin_tests, plugin_test_scenario_path)

return system, tests + plugin_tests, None
return system, list(set(tests + plugin_tests)), None

def _parse_with_scenario(
self,
Expand Down
7 changes: 3 additions & 4 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def parser(self, tmp_path: Path) -> Parser:
def test_no_tests_dir(self, parser: Parser):
tests_dir = parser.system_config_path.parent / "tests"
with pytest.raises(FileNotFoundError) as exc_info:
parser.parse(tests_dir, None, None)
parser.parse(tests_dir, None)
assert "Test path" in str(exc_info.value)

@patch("cloudai._core.test_parser.TestParser.parse_all")
Expand Down Expand Up @@ -93,7 +93,7 @@ def test_scenario_with_plugin_common_tests(
fake_plugin.test_runs[0].test.name = "test-1"
parse_plugins.return_value = {"plugin-1": fake_plugin}

_, tests, _ = parser.parse(tests_dir, Path(), Path())
_, tests, _ = parser.parse(tests_dir, Path())

assert len(tests) == 1
assert tests[0].name == "test-1"
Expand All @@ -103,7 +103,6 @@ def test_scenario_with_plugin_common_tests(
def test_scenario_with_plugin_exclusive_tests(self, test_scenario_parser: Mock, test_parser: Mock, parser: Parser):
tests_dir = parser.system_config_path.parent.parent / "test"
test_scenario_path = Path("/mock/test_scenario.toml")
plugin_path = Path("/mock/plugin_scenarios")

fake_tests = [Mock() for _ in range(4)]
for i, test in enumerate(fake_tests):
Expand All @@ -119,7 +118,7 @@ def test_scenario_with_plugin_exclusive_tests(self, test_scenario_parser: Mock,
fake_plugin_scenarios["plugin-1"].test_runs[0].test.name = "test-2"

with patch.object(parser, "_load_plugin_scenarios", return_value=fake_plugin_scenarios):
_, filtered_tests, _ = parser.parse(tests_dir, test_scenario_path, plugin_path)
_, filtered_tests, _ = parser.parse(tests_dir, test_scenario_path)

filtered_test_names = {t.name for t in filtered_tests}
assert len(filtered_tests) == 2
Expand Down

0 comments on commit ae22fc1

Please sign in to comment.