
Commit

add rerunning a test if a port is in use
xrsrke committed Feb 10, 2024
1 parent 29672db commit 0a34e65
Showing 12 changed files with 160 additions and 10 deletions.
117 changes: 116 additions & 1 deletion tests/helpers/utils.py
@@ -1,10 +1,13 @@
import contextlib
import os
import re
import uuid
from typing import Any, Dict, List, Optional, Tuple
from inspect import signature
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch.cuda
from nanotron.parallel import ParallelContext
from packaging import version
from torch.distributed.launcher import elastic_launch


@@ -185,3 +188,115 @@ def get_all_3d_configurations(gpus: int) -> List[Tuple[int, int, int]]:
if tp * dp * pp == gpus:
result.append((pp, dp, tp))
return result


def rerun_if_address_is_in_use():
    """
    Rerun a wrapped test function when an "address already in use" error occurs
    in tests spawned with torch.multiprocessing.
    Usage::
        @rerun_if_address_is_in_use()
        def test_something():
            ...
    """
    # check the torch version
    torch_version = version.parse(torch.__version__)
    assert torch_version.major >= 1

    # only torch >= 1.8 has ProcessRaisedException
    if torch_version >= version.parse("1.8.0"):
        exception = torch.multiprocessing.ProcessRaisedException
    else:
        exception = Exception

    func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*")
    return func_wrapper


def rerun_on_exception(exception_type: Exception = Exception, pattern: str = None, max_try: int = 5) -> Callable:
    """
    A decorator that reruns a function when an exception occurs.
    Usage::
        # rerun for all kinds of exceptions
        @rerun_on_exception()
        def test_method():
            print('hey')
            raise RuntimeError('Address already in use')
        # rerun for RuntimeError only
        @rerun_on_exception(exception_type=RuntimeError)
        def test_method():
            print('hey')
            raise RuntimeError('Address already in use')
        # rerun at most 10 times if a RuntimeError occurs
        @rerun_on_exception(exception_type=RuntimeError, max_try=10)
        def test_method():
            print('hey')
            raise RuntimeError('Address already in use')
        # rerun indefinitely if a RuntimeError keeps occurring
        @rerun_on_exception(exception_type=RuntimeError, max_try=None)
        def test_method():
            print('hey')
            raise RuntimeError('Address already in use')
        # rerun only if the exception message matches the pattern
        @rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$")
        def test_method():
            print('hey')
            raise RuntimeError('Address already in use')
    Args:
        exception_type (Exception, Optional): The type of exception to detect for rerun.
        pattern (str, Optional): The pattern to match against the exception message.
            If the pattern is not None and matches the exception message,
            the exception is detected for rerun.
        max_try (int, Optional): Maximum number of reruns for this function. The default value is 5.
            If max_try is None, it reruns forever as long as the exception keeps occurring.
    """

    def _match_lines(lines, pattern):
        for line in lines:
            if re.match(pattern, line):
                return True
        return False

    def _wrapper(func):
        def _run_until_success(*args, **kwargs):
            try_count = 0
            assert max_try is None or isinstance(
                max_try, int
            ), f"Expected max_try to be None or int, but got {type(max_try)}"

            while max_try is None or try_count < max_try:
                try:
                    try_count += 1
                    ret = func(*args, **kwargs)
                    return ret
                except exception_type as e:
                    error_lines = str(e).split("\n")
                    # retry only while attempts remain; when a pattern is specified,
                    # retry only if the exception message matches the pattern
                    if (max_try is None or try_count < max_try) and (
                        pattern is None or _match_lines(error_lines, pattern)
                    ):
                        print("Exception is caught, retrying...")
                        continue
                    else:
                        print("Maximum number of attempts is reached or pattern is not matched, no more retrying...")
                        raise e

        # Override the signature,
        # otherwise pytest.mark.parametrize will raise the following error:
        # function does not use argument xxx
        sig = signature(func)
        _run_until_success.__signature__ = sig

        return _run_until_success

    return _wrapper
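
For reference, here is a minimal usage sketch of the retry behavior, not part of this commit: the module path tests.helpers.utils, the name flaky_bind, and the attempt counter are assumptions made only for illustration. The decorated callable fails twice with the port-in-use message and then succeeds, so the decorator swallows the first two failures and returns the third result. The test files below stack the real decorator on top of pytest marks so that a port collision in the spawned process triggers a rerun instead of a test failure.

# Hypothetical usage sketch; names and import path are invented for illustration.
from tests.helpers.utils import rerun_on_exception

_attempts = {"count": 0}

@rerun_on_exception(exception_type=RuntimeError, pattern=".*Address already in use.*", max_try=5)
def flaky_bind():
    # fail twice with the port-in-use message, then succeed
    _attempts["count"] += 1
    if _attempts["count"] < 3:
        raise RuntimeError("Address already in use")
    return _attempts["count"]

if __name__ == "__main__":
    print(flaky_bind())  # the decorator retries twice and the call returns 3
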
6 changes: 5 additions & 1 deletion tests/test_clip_grads.py
@@ -4,7 +4,7 @@
import pytest
import torch
from helpers.dummy import DummyModel, dummy_infinite_data_loader
from helpers.utils import available_gpus, init_distributed
from helpers.utils import available_gpus, init_distributed, rerun_if_address_is_in_use
from nanotron import distributed as dist
from nanotron.models import init_on_device_and_dtype
from nanotron.optim.clip_grads import clip_grad_norm
@@ -32,6 +32,7 @@

@pytest.mark.skipif(available_gpus() < 2, reason="test_clip_grads_with_pp requires at least 2 gpus")
@pytest.mark.parametrize("norm_type", [math.inf, 1.0, 2.0])
@rerun_if_address_is_in_use()
def test_clip_grads_with_pp(norm_type: float):
init_distributed(tp=1, dp=1, pp=2)(_test_clip_grads_with_pp)(norm_type=norm_type)

@@ -198,6 +199,7 @@ def _test_clip_grads_with_pp(parallel_context: ParallelContext, norm_type: float
],
)
@pytest.mark.parametrize("norm_type", [math.inf, 1.0, 2.0])
@rerun_if_address_is_in_use()
def test_clip_grads_with_tp(tp_mode: TensorParallelLinearMode, async_communication: bool, norm_type: float):
init_distributed(tp=2, dp=1, pp=1)(_test_clip_grads_with_tp)(
tp_mode=tp_mode, async_communication=async_communication, norm_type=norm_type
@@ -339,6 +341,7 @@ def _test_clip_grads_with_tp(

@pytest.mark.skipif(available_gpus() < 2, reason="test_clip_grads_tied_weights requires at least 2 gpus")
@pytest.mark.parametrize("norm_type", [math.inf, 1.0, 2.0])
@rerun_if_address_is_in_use()
def test_clip_grads_tied_weights(norm_type: float):
init_distributed(tp=1, dp=1, pp=2)(_test_clip_grads_tied_weights)(norm_type=norm_type)

@@ -434,6 +437,7 @@ def _test_clip_grads_tied_weights(parallel_context: ParallelContext, norm_type:

@pytest.mark.parametrize("half_precision", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("norm_type", [math.inf, 1.0, 2.0])
@rerun_if_address_is_in_use()
def test_clip_grads_fp32_accumulator(norm_type: float, half_precision: torch.dtype):
init_distributed(tp=1, dp=1, pp=2)(_test_clip_grads_fp32_accumulator)(
norm_type=norm_type, half_precision=half_precision
3 changes: 2 additions & 1 deletion tests/test_data_parallel.py
@@ -3,7 +3,7 @@
import pytest
import torch
from helpers.exception import assert_fail_except_rank_with
from helpers.utils import available_gpus, init_distributed
from helpers.utils import available_gpus, init_distributed, rerun_if_address_is_in_use
from nanotron import distributed as dist
from nanotron.parallel import ParallelContext
from nanotron.parallel.data_parallel.utils import ddp_trigger_sync_in_bwd
@@ -15,6 +15,7 @@

@pytest.mark.skipif(available_gpus() < 2, reason="Testing test_ddp_with_afab requires at least 2 gpus")
@pytest.mark.parametrize("accumulation_steps", [1, 3])
@rerun_if_address_is_in_use()
def test_ddp_with_afab(accumulation_steps):
init_distributed(tp=1, dp=2, pp=1)(_test_ddp_with_afab)(accumulation_steps=accumulation_steps)

2 changes: 2 additions & 0 deletions tests/test_distributed.py
@@ -5,6 +5,7 @@
available_gpus,
get_all_3d_configurations,
init_distributed,
rerun_if_address_is_in_use,
)
from nanotron.parallel import ParallelContext
from torch.distributed import ProcessGroup
@@ -32,5 +33,6 @@ def _test_init_parallel_context(parallel_context: ParallelContext):
for all_3d_configs in get_all_3d_configurations(gpus)
],
)
@rerun_if_address_is_in_use()
def test_init_parallel_context(tp: int, dp: int, pp: int):
init_distributed(tp=tp, dp=dp, pp=pp)(_test_init_parallel_context)()
3 changes: 2 additions & 1 deletion tests/test_p2p.py
@@ -3,7 +3,7 @@
import pytest
import torch
from helpers.exception import assert_fail_with
from helpers.utils import available_gpus, init_distributed
from helpers.utils import available_gpus, init_distributed, rerun_if_address_is_in_use
from nanotron import distributed as dist
from nanotron.parallel import ParallelContext
from nanotron.parallel.pipeline_parallel.p2p import P2P
@@ -12,6 +12,7 @@
@pytest.mark.skipif(available_gpus() < 2, reason="Testing test_ddp_with_afab requires at least 2 gpus")
@pytest.mark.parametrize("send_contiguous", [True, False])
@pytest.mark.parametrize("full", [True, False])
@rerun_if_address_is_in_use()
def test_check_send_recv_tensor(send_contiguous: bool, full: bool):
init_distributed(tp=1, dp=1, pp=2)(_test_check_send_recv_tensor)(send_contiguous=send_contiguous, full=full)

4 changes: 3 additions & 1 deletion tests/test_parameters_accumulate_gradient_in_fp32.py
@@ -5,7 +5,7 @@
import torch
from helpers.dummy import DummyModel, dummy_infinite_data_loader
from helpers.exception import assert_fail_except_rank_with, timeout_after
from helpers.utils import available_gpus, init_distributed
from helpers.utils import available_gpus, init_distributed, rerun_if_address_is_in_use
from nanotron.models import init_on_device_and_dtype
from nanotron.optim import ZeroDistributedOptimizer
from nanotron.optim.gradient_accumulator import FP32GradBucketManager, FP32GradientAccumulator, get_fp32_accum_hook
@@ -141,6 +141,7 @@ def test_optimizer_can_step_gradient_in_fp32(half_precision: torch.dtype):
@pytest.mark.parametrize("half_precision", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("accumulation_steps", [1, 10])
@pytest.mark.parametrize("train_iterations", [1, 3])
@rerun_if_address_is_in_use()
def test_ddp_with_grad_accum_in_fp32(half_precision: torch.dtype, accumulation_steps: int, train_iterations: int):
init_distributed(tp=1, dp=2, pp=1)(_test_ddp_with_grad_accum_in_fp32)(
half_precision=half_precision,
@@ -306,6 +307,7 @@ def _test_ddp_with_grad_accum_in_fp32(
"pipeline_engine", [AllForwardAllBackwardPipelineEngine(), OneForwardOneBackwardPipelineEngine()]
)
@pytest.mark.parametrize("reduce_scatter", [True, False])
@rerun_if_address_is_in_use()
def test_tied_weights_sync_with_grad_accum_in_fp32(pipeline_engine: PipelineEngine, reduce_scatter: bool):
init_distributed(tp=1, dp=2, pp=2)(_test_tied_weights_sync_with_grad_accum_in_fp32)(
pipeline_engine=pipeline_engine, reduce_scatter=reduce_scatter
7 changes: 6 additions & 1 deletion tests/test_pipeline_parallel.py
@@ -3,7 +3,7 @@
import pytest
import torch
from helpers.dummy import DummyModel, dummy_infinite_data_loader
from helpers.utils import available_gpus, init_distributed
from helpers.utils import available_gpus, init_distributed, rerun_if_address_is_in_use
from nanotron import distributed as dist
from nanotron.models import init_on_device_and_dtype
from nanotron.parallel import ParallelContext
@@ -20,6 +20,7 @@


@pytest.mark.skipif(available_gpus() < 2, reason="Testing build_and_set_rank requires at least 2 gpus")
@rerun_if_address_is_in_use()
def test_build_and_set_rank():
init_distributed(tp=1, dp=1, pp=2)(_test_build_and_set_rank)()

@@ -67,6 +68,7 @@ def test_init_on_device_and_dtype():
"pipeline_engine", [AllForwardAllBackwardPipelineEngine(), OneForwardOneBackwardPipelineEngine()]
)
@pytest.mark.parametrize("pp", list(range(2, min(4, available_gpus()) + 1)))
@rerun_if_address_is_in_use()
def test_pipeline_engine(pipeline_engine: PipelineEngine, pp: int):
init_distributed(tp=1, dp=1, pp=pp)(_test_pipeline_engine)(pipeline_engine=pipeline_engine)

@@ -209,6 +211,7 @@ def _test_pipeline_engine(parallel_context: ParallelContext, pipeline_engine: Pi
"pipeline_engine", [AllForwardAllBackwardPipelineEngine(), OneForwardOneBackwardPipelineEngine()]
)
@pytest.mark.parametrize("pp", list(range(2, min(4, available_gpus()) + 1)))
@rerun_if_address_is_in_use()
def test_pipeline_engine_with_tensor_that_does_not_require_grad(pipeline_engine: PipelineEngine, pp: int):
init_distributed(pp=pp, dp=1, tp=1)(_test_pipeline_engine_with_tensor_that_does_not_require_grad)(
pipeline_engine=pipeline_engine
@@ -438,6 +441,7 @@ def dummy_infinite_data_loader_with_non_differentiable_tensor(


@pytest.mark.parametrize("pp", list(range(2, min(4, available_gpus()) + 1)))
@rerun_if_address_is_in_use()
def test_pipeline_forward_without_engine(pp: int):
init_distributed(pp=pp, dp=1, tp=1)(_test_pipeline_forward_without_engine)()

@@ -610,6 +614,7 @@ def dummy_infinite_data_loader_with_non_differentiable_tensor(
@pytest.mark.parametrize(
"pipeline_engine", [AllForwardAllBackwardPipelineEngine(), OneForwardOneBackwardPipelineEngine()]
)
@rerun_if_address_is_in_use()
def test_pipeline_engine_diamond(pipeline_engine: PipelineEngine):
init_distributed(pp=4, dp=1, tp=1)(_test_pipeline_engine_diamond)(pipeline_engine=pipeline_engine)
pass
3 changes: 2 additions & 1 deletion tests/test_random_state.py
@@ -1,6 +1,6 @@
import pytest
import torch
from helpers.utils import available_gpus, init_distributed
from helpers.utils import available_gpus, init_distributed, rerun_if_address_is_in_use
from nanotron import distributed as dist
from nanotron.parallel import ParallelContext
from nanotron.random import (
@@ -13,6 +13,7 @@

@pytest.mark.skipif(available_gpus() < 2, reason="Testing test_random_state_sync requires at least 2 gpus")
@pytest.mark.parametrize("tp,dp,pp", [(2, 1, 1), (1, 2, 1), (1, 1, 2)])
@rerun_if_address_is_in_use()
def test_random_state_sync(tp: int, dp: int, pp: int):
# TODO @nouamane: Make a test with 4 gpus (2 in one pg, 2 in other pg)
init_distributed(tp=tp, dp=dp, pp=pp)(_test_random_state_sync)()
9 changes: 9 additions & 0 deletions tests/test_serialize.py
@@ -7,6 +7,7 @@
get_all_3d_configurations,
init_distributed,
is_dict_equal,
rerun_if_address_is_in_use,
)
from nanotron import distributed as dist
from nanotron.constants import CHECKPOINT_VERSION
@@ -48,6 +49,7 @@ def test_save_and_load_with_changed_topolgy():
for all_3d_configs in get_all_3d_configurations(gpus)
],
)
@rerun_if_address_is_in_use()
def test_save_and_load_model(tp: int, dp: int, pp: int):
test_context = TestContext()
# We use DP=2 as we're interested in testing that one
@@ -87,6 +89,7 @@ def _test_save_and_load_model(parallel_context: ParallelContext, test_context: T
for all_3d_configs in get_all_3d_configurations(gpus)
],
)
@rerun_if_address_is_in_use()
def test_save_and_load_optimizer(tp: int, dp: int, pp: int):
test_context = TestContext()
# We use DP=2 as we're interested in testing that one
@@ -149,6 +152,7 @@ def _test_save_and_load_optimizer(parallel_context: ParallelContext, test_contex
for all_3d_configs in get_all_3d_configurations(gpus)
],
)
@rerun_if_address_is_in_use()
def test_save_zero_optimizer_and_load_optimizer(tp: int, dp: int, pp: int):
test_context = TestContext()
# We use DP=2 as we're interested in testing that one
@@ -220,6 +224,7 @@ def _test_save_zero_optimizer_and_load_optimizer(parallel_context: ParallelConte
for all_3d_configs in get_all_3d_configurations(gpus)
],
)
@rerun_if_address_is_in_use()
def test_save_zero_optimizer_and_load_data_parallel_optimizer(tp: int, dp: int, pp: int):
test_context = TestContext()
# We use DP=2 as we're interested in testing that one
@@ -289,6 +294,7 @@ def _test_save_zero_optimizer_and_load_data_parallel_optimizer(
for all_3d_configs in get_all_3d_configurations(gpus)
],
)
@rerun_if_address_is_in_use()
def test_save_data_parallel_optimizer_and_load_zero_optimizer(tp: int, dp: int, pp: int):
test_context = TestContext()
# We use DP=2 as we're interested in testing that one
@@ -354,6 +360,7 @@ def _test_save_data_parallel_optimizer_and_load_zero_optimizer(
for all_3d_configs in get_all_3d_configurations(gpus)
],
)
@rerun_if_address_is_in_use()
def test_save_optimizer_with_additional_state_dict_keys(tp: int, dp: int, pp: int):
test_context = TestContext()
# We use DP=2 as we're interested in testing that one
@@ -459,6 +466,7 @@ def _test_save_optimizer_with_additional_state_dict_keys(parallel_context: Paral


@pytest.mark.skipif(available_gpus() < 2, reason="Testing test_save_and_load_random_states requires at least 2 gpus")
@rerun_if_address_is_in_use()
def test_save_and_load_random_states():
test_context = TestContext()
# We use DP=2 as we're interested in testing
@@ -496,6 +504,7 @@ def _test_save_and_load_random_states(parallel_context: ParallelContext, test_co
assert random_states == new_random_states


@rerun_if_address_is_in_use()
def test_serialize_deserialize_tensormetadata():
test_context = TestContext()
init_distributed(tp=2, dp=1, pp=1)(_test_serialize_deserialize_tensormetadata)(test_context=test_context)
