update tests to add SP process group
zzhhjjj committed Jul 2, 2024
1 parent 3f504b3 commit 3f841ba
Showing 13 changed files with 97 additions and 85 deletions.
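Every change in this commit follows the same pattern: each init_distributed(tp=..., dp=..., pp=...) call in the tests gains an explicit sp=... argument for the new sequence-parallel (SP) process group, and the helper that enumerates parallelism layouts grows from 3D to 4D. The stand-in below is only a sketch of that call-site pattern — the real init_distributed lives in tests/helpers/utils.py and spawns distributed workers; DummyParallelContext and its field names are assumptions for illustration.

from dataclasses import dataclass
from typing import Callable

@dataclass
class DummyParallelContext:
    # Hypothetical stand-in for nanotron's ParallelContext; field names are assumed.
    tensor_parallel_size: int
    data_parallel_size: int
    pipeline_parallel_size: int
    sequence_parallel_size: int  # the new SP dimension added by this commit

def init_distributed(tp: int, dp: int, pp: int, sp: int) -> Callable:
    """Return a launcher that calls fn(parallel_context, **kwargs) on a tp * dp * pp * sp layout."""
    def launcher(fn: Callable) -> Callable:
        def run(**kwargs):
            ctx = DummyParallelContext(tp, dp, pp, sp)  # the real helper builds the process groups here
            return fn(ctx, **kwargs)
        return run
    return launcher

# Call sites change from init_distributed(tp=1, dp=1, pp=1) to:
init_distributed(tp=1, dp=1, pp=1, sp=1)(lambda ctx, **kwargs: print(ctx))()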
22 changes: 14 additions & 8 deletions examples/llama/tests/test_conversion.py
@@ -105,7 +105,7 @@ def _test_nt_to_hf(parallel_context: ParallelContext, input_ids: torch.Tensor):


def test_nt_to_hf(input_ids: torch.Tensor):
init_distributed(tp=1, dp=1, pp=1)(_test_nt_to_hf)(input_ids=input_ids)
init_distributed(tp=1, dp=1, pp=1, sp=1)(_test_nt_to_hf)(input_ids=input_ids)


def _test_nt_to_hf_with_files(parallel_context: ParallelContext, input_ids: torch.Tensor, test_context: TestContext):
@@ -130,7 +130,9 @@ def _test_nt_to_hf_with_files(parallel_context: ParallelContext, input_ids: torc


def test_nt_to_hf_with_files(input_ids: torch.Tensor):
init_distributed(tp=1, dp=1, pp=1)(_test_nt_to_hf_with_files)(input_ids=input_ids, test_context=TestContext())
init_distributed(tp=1, dp=1, pp=1, sp=1)(_test_nt_to_hf_with_files)(
input_ids=input_ids, test_context=TestContext()
)


def _test_hf_to_nt(parallel_context: ParallelContext, input_ids: torch.Tensor):
@@ -141,11 +143,11 @@ def _test_hf_to_nt(parallel_context: ParallelContext, input_ids: torch.Tensor):
logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2)
logits_hf = model_hf(input_ids).logits
assert logits_nt.size() == logits_hf.size()
torch.testing.assert_allclose(logits_hf, logits_nt, atol=ATOL)
torch.testing.assert_allclose(logits_hf, logits_nt, atol=ATOL)


def test_hf_to_nt(input_ids: torch.Tensor):
init_distributed(tp=1, dp=1, pp=1)(_test_hf_to_nt)(input_ids=input_ids)
init_distributed(tp=1, dp=1, pp=1, sp=1)(_test_hf_to_nt)(input_ids=input_ids)


def _test_hf_to_nt_with_files(parallel_context: ParallelContext, input_ids: torch.Tensor, test_context: TestContext):
@@ -168,7 +170,9 @@ def _test_hf_to_nt_with_files(parallel_context: ParallelContext, input_ids: torc


def test_hf_to_nt_with_files(input_ids: torch.Tensor):
init_distributed(tp=1, dp=1, pp=1)(_test_hf_to_nt_with_files)(input_ids=input_ids, test_context=TestContext())
init_distributed(tp=1, dp=1, pp=1, sp=1)(_test_hf_to_nt_with_files)(
input_ids=input_ids, test_context=TestContext()
)


def _test_composed_conversion(parallel_context: ParallelContext):
@@ -196,7 +200,7 @@ def _test_composed_conversion(parallel_context: ParallelContext):


def test_composed_conversion():
init_distributed(tp=1, dp=1, pp=1)(_test_composed_conversion)()
init_distributed(tp=1, dp=1, pp=1, sp=1)(_test_composed_conversion)()


def _save_parallel_nanotron(parallel_context: ParallelContext, input_ids: torch.Tensor, nt_path: Path):
@@ -239,9 +243,11 @@ def test_tensor_parallel_conversion(input_ids: torch.Tensor):
hf_path = root / "nanotron"

# Launch both parts.
init_distributed(tp=2, dp=1, pp=1)(_save_parallel_nanotron)(input_ids=input_ids, nt_path=nt_path)
init_distributed(tp=2, dp=1, pp=1, sp=1)(_save_parallel_nanotron)(input_ids=input_ids, nt_path=nt_path)
assert (nt_path / "logits.pt").exists()
init_distributed(tp=1, dp=1, pp=1)(_convert_from_parallel)(input_ids=input_ids, nt_path=nt_path, hf_path=hf_path)
init_distributed(tp=1, dp=1, pp=1, sp=1)(_convert_from_parallel)(
input_ids=input_ids, nt_path=nt_path, hf_path=hf_path
)
assert (hf_path / "logits.pt").exists()

# Load logits and verify they match.
10 changes: 6 additions & 4 deletions tests/helpers/utils.py
@@ -107,8 +107,8 @@ def is_dict_equal(first: Dict, second: Dict, sub_paths: Optional[List[str]] = No
return True, None


def get_all_3d_configurations(gpus: int) -> List[Tuple[int, int, int]]:
"""Given a number of gpus, we want all 3d configurations possible such that pp * dp * tp = gpus"""
def get_all_4d_configurations(gpus: int) -> List[Tuple[int, int, int, int]]:
"""Given a number of gpus, we want all 4d configurations possible such that pp * dp * tp * sp = gpus"""
result = []
for tp in range(1, gpus + 1):
if gpus % tp != 0:
@@ -121,8 +121,10 @@ def get_all_3d_configurations(gpus: int) -> List[Tuple[int, int, int]]:
for pp in range(1, gpus_left_after_dp + 1):
if gpus_left_after_dp % pp != 0:
continue
if tp * dp * pp == gpus:
result.append((pp, dp, tp))
gpus_left_after_pp = gpus_left_after_dp // pp
for sp in range(1, gpus_left_after_pp + 1):
if tp * dp * pp * sp == gpus:
result.append((tp, dp, pp, sp))
return result
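For reference, the complete helper after this change reads roughly as follows. The tp and dp bookkeeping in the middle of the loop nest sits outside the hunk above, so those lines (and the gpus_left_after_tp name) are reconstructed assumptions; only the sp loop and the appended 4-tuple come verbatim from the diff.

from typing import List, Tuple

def get_all_4d_configurations(gpus: int) -> List[Tuple[int, int, int, int]]:
    """Given a number of gpus, we want all 4d configurations possible such that pp * dp * tp * sp = gpus"""
    result = []
    for tp in range(1, gpus + 1):
        if gpus % tp != 0:
            continue
        gpus_left_after_tp = gpus // tp  # reconstructed: not shown in the hunk
        for dp in range(1, gpus_left_after_tp + 1):
            if gpus_left_after_tp % dp != 0:
                continue
            gpus_left_after_dp = gpus_left_after_tp // dp  # reconstructed: not shown in the hunk
            for pp in range(1, gpus_left_after_dp + 1):
                if gpus_left_after_dp % pp != 0:
                    continue
                gpus_left_after_pp = gpus_left_after_dp // pp
                for sp in range(1, gpus_left_after_pp + 1):
                    if tp * dp * pp * sp == gpus:
                        result.append((tp, dp, pp, sp))
    return result

# Example: 4 GPUs admit exactly 10 ordered (tp, dp, pp, sp) factorizations of 4.
assert len(get_all_4d_configurations(4)) == 10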


30 changes: 16 additions & 14 deletions tests/nanoset/test_build_nanoset_dataloader.py
@@ -2,9 +2,6 @@
from math import isclose
from pathlib import Path

package_path = Path(__file__).parent.parent
sys.path.append(str(package_path))

import numpy as np
import pytest
from helpers.context import TestContext
@@ -16,29 +13,32 @@
create_dummy_json_dataset,
preprocess_dummy_dataset,
)
from helpers.utils import available_gpus, get_all_3d_configurations, init_distributed, rerun_if_address_is_in_use
from helpers.utils import available_gpus, get_all_4d_configurations, init_distributed, rerun_if_address_is_in_use
from nanotron.data.dataloader_builder import build_nanoset_dataloader
from nanotron.data.nanoset import Nanoset
from nanotron.data.utils import count_dataset_indexes, normalize
from nanotron.parallel import ParallelContext
from nanotron.utils import main_rank_first
from transformers import AutoTokenizer

package_path = Path(__file__).parent.parent
sys.path.append(str(package_path))


@pytest.mark.parametrize(
"tp,dp,pp",
"tp,dp,pp,sp",
[
pytest.param(*all_3d_configs)
pytest.param(*all_4d_configs)
for gpus in range(1, min(available_gpus(), 4) + 1)
for all_3d_configs in get_all_3d_configurations(gpus)
for all_4d_configs in get_all_4d_configurations(gpus)
],
)
@pytest.mark.parametrize("train_steps", [5, 100])
@pytest.mark.parametrize("sequence_length", [512, 8192])
@pytest.mark.parametrize("tokenizer_name_or_path", ["openai-community/gpt2", "unsloth/llama-3-8b-bnb-4bit"])
@rerun_if_address_is_in_use()
def test_build_nanoset_dataloader(
tp: int, dp: int, pp: int, train_steps: int, sequence_length: int, tokenizer_name_or_path: str
tp: int, dp: int, pp: int, sp: int, train_steps: int, sequence_length: int, tokenizer_name_or_path: str
):
test_context = TestContext()

@@ -49,7 +49,7 @@ def test_build_nanoset_dataloader(
for idx, json_path in enumerate(json_paths):
create_dummy_json_dataset(path_to_json=json_path, dummy_text=f"Nanoset {idx}!", n_samples=(idx + 1) * 50000)

init_distributed(tp=tp, dp=dp, pp=pp)(_test_build_nanoset_dataloader)(
init_distributed(tp=tp, dp=dp, pp=pp, sp=sp)(_test_build_nanoset_dataloader)(
json_paths=json_paths,
path_to_mmap_files=mmap_dataset_paths,
train_steps=train_steps,
@@ -155,17 +155,19 @@ def _test_build_nanoset_dataloader(


@pytest.mark.parametrize(
"tp,dp,pp",
"tp,dp,pp,sp",
[
pytest.param(*all_3d_configs)
pytest.param(*all_4d_configs)
for gpus in range(1, min(available_gpus(), 4) + 1)
for all_3d_configs in get_all_3d_configurations(gpus)
for all_4d_configs in get_all_4d_configurations(gpus)
],
)
@pytest.mark.parametrize("skipped_batches", [20, 50])
@pytest.mark.parametrize("tokenizer_name_or_path", ["openai-community/gpt2", "unsloth/llama-3-8b-bnb-4bit"])
@rerun_if_address_is_in_use()
def test_recover_nanoset_dataloader(tp: int, dp: int, pp: int, skipped_batches: int, tokenizer_name_or_path: str):
def test_recover_nanoset_dataloader(
tp: int, dp: int, pp: int, sp: int, skipped_batches: int, tokenizer_name_or_path: str
):
test_context = TestContext()

# Create dataset files
@@ -175,7 +177,7 @@ def test_recover_nanoset_dataloader(tp: int, dp: int, pp: int, skipped_batches:
for idx, json_path in enumerate(json_paths):
create_dummy_json_dataset(path_to_json=json_path, dummy_text=f"Nanoset {idx}!", n_samples=(idx + 1) * 50000)

init_distributed(tp=tp, dp=dp, pp=pp)(_test_recover_nanoset_dataloader)(
init_distributed(tp=tp, dp=dp, pp=pp, sp=sp)(_test_recover_nanoset_dataloader)(
json_paths=json_paths,
path_to_mmap_files=mmap_dataset_paths,
skipped_batches=skipped_batches,
8 changes: 4 additions & 4 deletions tests/test_clip_grads.py
@@ -34,7 +34,7 @@
@pytest.mark.parametrize("norm_type", [math.inf, 1.0, 2.0])
@rerun_if_address_is_in_use()
def test_clip_grads_with_pp(norm_type: float):
init_distributed(tp=1, dp=1, pp=2)(_test_clip_grads_with_pp)(norm_type=norm_type)
init_distributed(tp=1, dp=1, pp=2, sp=1)(_test_clip_grads_with_pp)(norm_type=norm_type)


def _test_clip_grads_with_pp(parallel_context: ParallelContext, norm_type: float):
@@ -203,7 +203,7 @@ def _test_clip_grads_with_pp(parallel_context: ParallelContext, norm_type: float
@pytest.mark.parametrize("norm_type", [math.inf, 1.0, 2.0])
@rerun_if_address_is_in_use()
def test_clip_grads_with_tp(tp_mode: TensorParallelLinearMode, async_communication: bool, norm_type: float):
init_distributed(tp=2, dp=1, pp=1)(_test_clip_grads_with_tp)(
init_distributed(tp=2, dp=1, pp=1, sp=1)(_test_clip_grads_with_tp)(
tp_mode=tp_mode, async_communication=async_communication, norm_type=norm_type
)

@@ -345,7 +345,7 @@ def _test_clip_grads_with_tp(
@pytest.mark.parametrize("norm_type", [math.inf, 1.0, 2.0])
@rerun_if_address_is_in_use()
def test_clip_grads_tied_weights(norm_type: float):
init_distributed(tp=1, dp=1, pp=2)(_test_clip_grads_tied_weights)(norm_type=norm_type)
init_distributed(tp=1, dp=1, pp=2, sp=1)(_test_clip_grads_tied_weights)(norm_type=norm_type)


def _test_clip_grads_tied_weights(parallel_context: ParallelContext, norm_type: float):
@@ -438,7 +438,7 @@ def _test_clip_grads_tied_weights(parallel_context: ParallelContext, norm_type:
@pytest.mark.parametrize("norm_type", [math.inf, 1.0, 2.0])
@rerun_if_address_is_in_use()
def test_clip_grads_fp32_accumulator(norm_type: float, half_precision: torch.dtype):
init_distributed(tp=1, dp=1, pp=2)(_test_clip_grads_fp32_accumulator)(
init_distributed(tp=1, dp=1, pp=2, sp=1)(_test_clip_grads_fp32_accumulator)(
norm_type=norm_type, half_precision=half_precision
)

2 changes: 1 addition & 1 deletion tests/test_data_parallel.py
@@ -17,7 +17,7 @@
@pytest.mark.parametrize("accumulation_steps", [1, 3])
@rerun_if_address_is_in_use()
def test_ddp_with_afab(accumulation_steps):
init_distributed(tp=1, dp=2, pp=1)(_test_ddp_with_afab)(accumulation_steps=accumulation_steps)
init_distributed(tp=1, dp=2, pp=1, sp=1)(_test_ddp_with_afab)(accumulation_steps=accumulation_steps)


def _test_ddp_with_afab(parallel_context: ParallelContext, accumulation_steps: int):
12 changes: 6 additions & 6 deletions tests/test_distributed.py
@@ -3,7 +3,7 @@
import torch.distributed as dist
from helpers.utils import (
available_gpus,
get_all_3d_configurations,
get_all_4d_configurations,
init_distributed,
rerun_if_address_is_in_use,
)
@@ -36,13 +36,13 @@ def _test_init_parallel_context(parallel_context: ParallelContext):


@pytest.mark.parametrize(
"tp,dp,pp",
"tp,dp,pp,sp",
[
pytest.param(*all_3d_configs)
pytest.param(*all_4d_configs)
for gpus in range(1, min(available_gpus(), 4) + 1)
for all_3d_configs in get_all_3d_configurations(gpus)
for all_4d_configs in get_all_4d_configurations(gpus)
],
)
@rerun_if_address_is_in_use()
def test_init_parallel_context(tp: int, dp: int, pp: int):
init_distributed(tp=tp, dp=dp, pp=pp)(_test_init_parallel_context)()
def test_init_parallel_context(tp: int, dp: int, pp: int, sp: int):
init_distributed(tp=tp, dp=dp, pp=pp, sp=sp)(_test_init_parallel_context)()
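As a rough sanity check of what the 4D parametrization above expands to, the sketch below assumes at least 4 visible GPUs (the tests cap the sweep at min(available_gpus(), 4)) and reuses the helper from tests/helpers/utils.py.

import pytest
from helpers.utils import get_all_4d_configurations

params = [
    pytest.param(*all_4d_configs)
    for gpus in range(1, 4 + 1)
    for all_4d_configs in get_all_4d_configurations(gpus)
]
# 1 + 4 + 4 + 10 = 19 (tp, dp, pp, sp) layouts in total,
# e.g. (1, 1, 1, 1), (2, 1, 1, 1), (1, 1, 2, 2), (4, 1, 1, 1).
print(len(params))  # 19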
2 changes: 1 addition & 1 deletion tests/test_optimizer_params_groups.py
@@ -496,7 +496,7 @@ def optimizer_builder(inp_param_groups):
def test_ddp_optimizer_grad_accumulation_lr_weight_decay_multiple_group(
half_precision: torch.dtype, accumulation_steps: int
):
init_distributed(tp=1, dp=2, pp=1)(_test_ddp_optimizer_grad_accumulation_lr_weight_decay_multiple_group)(
init_distributed(tp=1, dp=2, pp=1, sp=1)(_test_ddp_optimizer_grad_accumulation_lr_weight_decay_multiple_group)(
half_precision=half_precision,
accumulation_steps=accumulation_steps,
)
2 changes: 1 addition & 1 deletion tests/test_p2p.py
@@ -14,7 +14,7 @@
@pytest.mark.parametrize("full", [True, False])
@rerun_if_address_is_in_use()
def test_check_send_recv_tensor(send_contiguous: bool, full: bool):
init_distributed(tp=1, dp=1, pp=2)(_test_check_send_recv_tensor)(send_contiguous=send_contiguous, full=full)
init_distributed(tp=1, dp=1, pp=2, sp=1)(_test_check_send_recv_tensor)(send_contiguous=send_contiguous, full=full)


def _test_check_send_recv_tensor(parallel_context: ParallelContext, send_contiguous: bool, full: bool):
6 changes: 3 additions & 3 deletions tests/test_parameters_accumulate_gradient_in_fp32.py
@@ -143,7 +143,7 @@ def test_optimizer_can_step_gradient_in_fp32(half_precision: torch.dtype):
@pytest.mark.parametrize("train_iterations", [1, 3])
@rerun_if_address_is_in_use()
def test_ddp_with_grad_accum_in_fp32(half_precision: torch.dtype, accumulation_steps: int, train_iterations: int):
init_distributed(tp=1, dp=2, pp=1)(_test_ddp_with_grad_accum_in_fp32)(
init_distributed(tp=1, dp=2, pp=1, sp=1)(_test_ddp_with_grad_accum_in_fp32)(
half_precision=half_precision,
accumulation_steps=accumulation_steps,
train_iterations=train_iterations,
@@ -257,7 +257,7 @@ def _test_ddp_with_grad_accum_in_fp32(
accumulator.backward(loss_fp32_accum)

for name, param in model_ddp_fp32_accum.named_parameters():
# Check that half grads has been set to None in sync step, to avoid it being uncorrectly used
# Check that half grads has been set to None in sync step, to avoid it being incorrectly used
half_grad = param.grad
assert half_grad is None, f"{half_grad} != None"

@@ -310,7 +310,7 @@ def _test_ddp_with_grad_accum_in_fp32(
@pytest.mark.parametrize("reduce_scatter", [True, False])
@rerun_if_address_is_in_use()
def test_tied_weights_sync_with_grad_accum_in_fp32(pipeline_engine: PipelineEngine, reduce_scatter: bool):
init_distributed(tp=1, dp=2, pp=2)(_test_tied_weights_sync_with_grad_accum_in_fp32)(
init_distributed(tp=1, dp=2, pp=2, sp=1)(_test_tied_weights_sync_with_grad_accum_in_fp32)(
pipeline_engine=pipeline_engine, reduce_scatter=reduce_scatter
)

10 changes: 5 additions & 5 deletions tests/test_pipeline_parallel.py
@@ -22,7 +22,7 @@
@pytest.mark.skipif(available_gpus() < 2, reason="Testing build_and_set_rank requires at least 2 gpus")
@rerun_if_address_is_in_use()
def test_build_and_set_rank():
init_distributed(tp=1, dp=1, pp=2)(_test_build_and_set_rank)()
init_distributed(tp=1, dp=1, pp=2, sp=1)(_test_build_and_set_rank)()


def _test_build_and_set_rank(parallel_context: ParallelContext):
@@ -72,7 +72,7 @@ def test_init_on_device_and_dtype():
@pytest.mark.parametrize("pp", list(range(2, min(4, available_gpus()) + 1)))
@rerun_if_address_is_in_use()
def test_pipeline_engine(pipeline_engine: PipelineEngine, pp: int):
init_distributed(tp=1, dp=1, pp=pp)(_test_pipeline_engine)(pipeline_engine=pipeline_engine)
init_distributed(tp=1, dp=1, pp=pp, sp=1)(_test_pipeline_engine)(pipeline_engine=pipeline_engine)


def _test_pipeline_engine(parallel_context: ParallelContext, pipeline_engine: PipelineEngine):
@@ -217,7 +217,7 @@ def _test_pipeline_engine(parallel_context: ParallelContext, pipeline_engine: Pi
@pytest.mark.parametrize("pp", list(range(2, min(4, available_gpus()) + 1)))
@rerun_if_address_is_in_use()
def test_pipeline_engine_with_tensor_that_does_not_require_grad(pipeline_engine: PipelineEngine, pp: int):
init_distributed(pp=pp, dp=1, tp=1)(_test_pipeline_engine_with_tensor_that_does_not_require_grad)(
init_distributed(pp=pp, dp=1, tp=1, sp=1)(_test_pipeline_engine_with_tensor_that_does_not_require_grad)(
pipeline_engine=pipeline_engine
)

@@ -448,7 +448,7 @@ def dummy_infinite_data_loader_with_non_differentiable_tensor(
@pytest.mark.parametrize("pp", list(range(2, min(4, available_gpus()) + 1)))
@rerun_if_address_is_in_use()
def test_pipeline_forward_without_engine(pp: int):
init_distributed(pp=pp, dp=1, tp=1)(_test_pipeline_forward_without_engine)()
init_distributed(pp=pp, dp=1, tp=1, sp=1)(_test_pipeline_forward_without_engine)()


def _test_pipeline_forward_without_engine(parallel_context: ParallelContext):
@@ -623,7 +623,7 @@ def dummy_infinite_data_loader_with_non_differentiable_tensor(
)
@rerun_if_address_is_in_use()
def test_pipeline_engine_diamond(pipeline_engine: PipelineEngine):
init_distributed(pp=4, dp=1, tp=1)(_test_pipeline_engine_diamond)(pipeline_engine=pipeline_engine)
init_distributed(pp=4, dp=1, tp=1, sp=1)(_test_pipeline_engine_diamond)(pipeline_engine=pipeline_engine)
pass

