diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 00000000..126bd25e --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,43 @@ +name: lint + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + check_code_quality: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python + id: setup_python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + - name: Load cached virtual environment + uses: actions/cache@v3 + id: cache-venv + with: + path: | + ~/.venv/ + ~/.cache/pre-commit/ + .git/hooks/pre-commit + key: ${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-venv-${{ hashFiles('pyproject.toml') }} + - name: Install dependencies + run: | + python -m venv ~/.venv + source ~/.venv/bin/activate + python -m pip install -e .[dev] + pre-commit install + if: steps.cache-venv.outputs.cache-hit != 'true' + - name: Check quality + run: | + source ~/.venv/bin/activate + python -m pip install --no-deps -e .[dev] + pre-commit run --config .pre-commit-config-check.yaml --all-files diff --git a/README.md b/README.md index 8cded785..8e0213c3 100644 --- a/README.md +++ b/README.md @@ -1 +1 @@ -# nanotron \ No newline at end of file +# nanotron diff --git a/configs/config.yaml b/configs/config.yaml new file mode 100644 index 00000000..129f2439 --- /dev/null +++ b/configs/config.yaml @@ -0,0 +1,111 @@ +# USE_FAST=1 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=8 scripts/train.py --config-file configs/config.yaml +# 09/25/2023 09:55:06 [INFO|DP=0|PP=0|TP=0]: [After train batch iter] Memory usage: 18459.78MB. Peak reserved memory: 69208.00MB +# 09/25/2023 09:55:07 [INFO|DP=0|PP=1|TP=0]: iteration: 2 / 300 | consumed_samples: 1024 | elapsed_time_per_iteration_ms: 58748.9 | tokens_per_sec: 3.569689E+04 | tokens_per_sec_per_gpu: 4.462111E+03 | global_batch_size: 512 | lm_loss: 1.130280E+01 | lr: 5.333E-07 | model_tflops_per_gpu: 185.96 | hardware_tflops_per_gpu: 195.54 | grad_norm: 1.618 +general: + name: test-llama + ignore_sanity_checks: false + kill_switch_path: ./kill_switch_nouamane + +profile: # + # profiler_export_path: profile + +checkpoints: + checkpoints_path: /fsx/nouamane/checkpoints/nanotron/test + load_from_specific_checkpoint: null + checkpoint_interval: 1000000 + +parallelism: + dp: 2 + pp: 2 + tp: 2 + pp_engine: 1f1b + tp_mode: REDUCE_SCATTER + tp_linear_async_communication: true + recompute_granularity: selective + +model: + hf_model_name: HuggingFaceBR4/llama-v2-7b-the-pile + # hf_model_name: huggyllama/llama-7b + # hf_model_name: meta-llama/Llama-2-7b-hf + remote_code: + trust_remote_code: true + make_vocab_size_divisible_by: 1 + init_method: + std: 0.015625 # Basically 1/sqrt(N) + # path: /fsx/nouamane/projects/nanotron/pretrained/llama-v2-7b-the-pile + # path: /fsx/nouamane/projects/nanotron/pretrained/llama-2-7b + dtype: bfloat16 + seed: 42 + +tokens: + sequence_length: 4096 + train_steps: 300 # GBS = 1024 -> Train steps = 111998 / 512 = 160 + micro_batch_size: 4 + batch_accumulation_per_replica: 64 + val_check_interval: 20 + limit_val_batches: 2 + +optimizer: + zero_stage: 1 + weight_decay: 0.1 + clip_grad: 0.4 + + accumulate_grad_in_fp32: true + + adam_eps: 1.0e-8 + adam_beta1: 0.9 + adam_beta2: 0.95 # Copied from LLaMa + torch_adam_is_fused: true + + learning_rate: 4.0e-4 + +learning_rate_scheduler: + lr_warmup_steps: 1500 + lr_warmup_style: linear + lr_decay_steps: null + lr_decay_style: linear + min_decay_lr: 4.0e-5 + +logging: + # 'debug', 'info', 'warning', 'error', 'critical' and 'passive' + log_level: 'info' + log_level_replica: 'info' + iteration_step_info_interval: 1 + tensorboard_logger: + # tensorboard_dir: ./tensorboard_llama + # # flush_secs: 20 + # repo_id: HuggingFaceBR4/nouamane-llama-2-finetuning-clean + # push_to_hub_interval: 20 + # repo_public: False + +data: + seed: 1234 + num_loading_workers: 1 + dataset: + # hf_dataset_mixer: + # # HuggingFaceH4/oasst1_h4: 1.0 # 20504 -> 20k + # HuggingFaceH4/anthropic_helpful: 1.0 # 111998 -> 20k + # # HuggingFaceH4/shp: 0 # 82836 -> 20k + # # HuggingFaceH4/learn_to_summarize: 0.527 # 37962 -> 20k + # # HuggingFaceH4/scale_helpful_1: 1.0 # 800 + # hf_dataset_splits: + # - train_ift + # # - train_rm + # # - test_rm # # TODO @nouamane: support evaluation + # hf_dataset_config_name: null + # dataset_processing_num_proc_per_process: 12 + # dataset_overwrite_cache: false + # text_column_name: chosen + + # data_prefix: + # - 1 + # - /fsx/thomwolf/data/llama-samantha_result_document + # index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix + # splits_string: 0.969,0.03,0.001 # train, val, test (we normalize by sum) + # # rm /fsx/shared-falcon-180B/data/tokenized_stack_no_pii/code/python/*.npy to reset cache + # skip_warmup: true + # dataloader_type: single # cyclic + # validation_drop_last: true # Set to false if the last partial validation samples is to be consumed + # eod_mask_loss: false # Mask loss for the end of document tokens + # no_seqlen_plus_one_input_tokens: false # Set to true to disable fetching (sequence length + 1) input tokens, instead get (sequence length) input tokens and mask the last token + # pad_samples_to_global_batch_size: false # Set to true if you want to pad the last partial batch with -1's to equal global batch size diff --git a/configs/config_correctness.yaml b/configs/config_correctness.yaml new file mode 100644 index 00000000..a7281c55 --- /dev/null +++ b/configs/config_correctness.yaml @@ -0,0 +1,109 @@ +# USE_FAST=1 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=8 scripts/train.py --config-file configs/config_correctness.yaml +general: + name: test-llama + ignore_sanity_checks: false + kill_switch_path: ./kill_switch_nouamane + +profile: # + # profiler_export_path: profile + +checkpoints: + checkpoints_path: /fsx/nouamane/checkpoints/nanotron/test + load_from_specific_checkpoint: null + checkpoint_interval: 1000000 + +parallelism: + dp: 2 + pp: 2 + tp: 2 + pp_engine: 1f1b + tp_mode: REDUCE_SCATTER + tp_linear_async_communication: true + recompute_granularity: selective + +model: + # hf_model_name: HuggingFaceBR4/llama-v2-7b-the-pile + # hf_model_name: huggyllama/llama-7b + hf_model_name: meta-llama/Llama-2-7b-hf + remote_code: + trust_remote_code: true + make_vocab_size_divisible_by: 1 + init_method: + # std: 0.015625 # Basically 1/sqrt(N) + # path: /fsx/nouamane/projects/nanotron/pretrained/llama-v2-7b-the-pile + path: /fsx/nouamane/projects/brrr/pretrained/llama-2-7b + dtype: bfloat16 + seed: 42 + +tokens: + sequence_length: 4096 + train_steps: 300 # GBS = 1024 -> Train steps = 111998 / 512 = 160 + micro_batch_size: 2 + batch_accumulation_per_replica: 3 + val_check_interval: 20 + limit_val_batches: 2 + +optimizer: + zero_stage: 1 + weight_decay: 0.1 + clip_grad: 0.4 + + accumulate_grad_in_fp32: true + + adam_eps: 1.0e-8 + adam_beta1: 0.9 + adam_beta2: 0.95 # Copied from LLaMa + torch_adam_is_fused: true + + learning_rate: 4.0e-4 + +learning_rate_scheduler: + lr_warmup_steps: 1500 + lr_warmup_style: linear + lr_decay_steps: null + lr_decay_style: linear + min_decay_lr: 4.0e-5 + +logging: + # 'debug', 'info', 'warning', 'error', 'critical' and 'passive' + log_level: 'info' + log_level_replica: 'info' + iteration_step_info_interval: 1 + tensorboard_logger: + # tensorboard_dir: /fsx/nouamane/projects/nanotron/tb_logs + # # flush_secs: 20 + # repo_id: HuggingFaceBR4/nouamane-llama-2-finetuning-clean + # push_to_hub_interval: 20 + # repo_public: False + +data: + seed: 1234 + num_loading_workers: 1 + dataset: + # hf_dataset_mixer: + # # HuggingFaceH4/oasst1_h4: 1.0 # 20504 -> 20k + # HuggingFaceH4/anthropic_helpful: 1.0 # 111998 -> 20k + # # HuggingFaceH4/shp: 0 # 82836 -> 20k + # # HuggingFaceH4/learn_to_summarize: 0.527 # 37962 -> 20k + # # HuggingFaceH4/scale_helpful_1: 1.0 # 800 + # hf_dataset_splits: + # - train_ift + # # - train_rm + # # - test_rm # # TODO @nouamane: support evaluation + # hf_dataset_config_name: null + # dataset_processing_num_proc_per_process: 12 + # dataset_overwrite_cache: false + # text_column_name: chosen + + data_prefix: + - 1 + - /fsx/nouamane/data/llama-samantha/llama-samantha_result_document + index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix + splits_string: 0.969,0.03,0.001 # train, val, test (we normalize by sum) + # rm /fsx/shared-falcon-180B/data/tokenized_stack_no_pii/code/python/*.npy to reset cache + skip_warmup: true + dataloader_type: single # cyclic + validation_drop_last: true # Set to false if the last partial validation samples is to be consumed + eod_mask_loss: false # Mask loss for the end of document tokens + no_seqlen_plus_one_input_tokens: false # Set to true to disable fetching (sequence length + 1) input tokens, instead get (sequence length) input tokens and mask the last token + pad_samples_to_global_batch_size: false # Set to true if you want to pad the last partial batch with -1's to equal global batch size diff --git a/scripts/generate.py b/scripts/generate.py new file mode 100644 index 00000000..e5197226 --- /dev/null +++ b/scripts/generate.py @@ -0,0 +1,144 @@ +""" Example of generation with a pretrained Llama model. + +USE_FAST=1 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 scripts/generate.py --pp 2 --tp 2 --model_name huggyllama/llama-7b --ckpt-path /fsx/nouamane/projects/brrr/pretrained/llama-2-7b +""" +import argparse +from pathlib import Path + +import torch +from nanotron.config import ParallelismArgs +from nanotron.core import distributed as dist +from nanotron.core import logging +from nanotron.core.logging import log_rank +from nanotron.core.parallelism.parameters import sanity_check +from nanotron.core.parallelism.pipeline_parallelism.engine import ( + OneForwardOneBackwardPipelineEngine, +) +from nanotron.core.parallelism.pipeline_parallelism.tensor_pointer import TensorPointer +from nanotron.core.parallelism.tensor_parallelism.enum import TensorParallelLinearMode +from nanotron.core.process_groups_initializer import get_process_groups +from nanotron.core.random import ( + set_random_seed, +) +from nanotron.core.serialize import ( + load_weights, +) +from nanotron.generation import GenerationConfig, GenerationInput, TokenizerConfig, greedy_search +from nanotron.trainer import CONFIG_TO_MODEL_CLASS, DistributedTrainer, mark_tied_parameters +from transformers import AutoConfig, AutoTokenizer + +logger = logging.get_logger(__name__) + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_name", type=str, required=True, help="Model name") + parser.add_argument("--ckpt-path", type=Path, required=True, help="Checkpoint path") + parser.add_argument("--dp", type=int, default=1) + parser.add_argument("--pp", type=int, default=2) + parser.add_argument("--tp", type=int, default=1) + return parser.parse_args() + + +def main(): + args = get_args() + checkpoint_path = args.ckpt_path + parallel_config = ParallelismArgs( + dp=args.dp, + pp=args.pp, + tp=args.tp, + pp_engine=OneForwardOneBackwardPipelineEngine(), + tp_mode=TensorParallelLinearMode.ALL_REDUCE, + recompute_granularity=None, + tp_linear_async_communication=False, + ) + dtype = torch.bfloat16 + + # Set random states + set_random_seed(42) + + # Initialise all process groups + dpg = get_process_groups( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + model_name = args.model_name + model_config: AutoConfig = AutoConfig.from_pretrained(model_name) + # model_config.num_hidden_layers = 1 + + model_config_cls = model_config.__class__.__name__ + if model_config_cls not in CONFIG_TO_MODEL_CLASS: + raise ValueError( + f"Unsupported model config {model_config_cls}. Only {CONFIG_TO_MODEL_CLASS.keys()} are supported" + ) + + model = DistributedTrainer.build_model( + model_builder=lambda: CONFIG_TO_MODEL_CLASS[model_config_cls]( + config=model_config, + dpg=dpg, + parallel_config=parallel_config, + random_states=None, + ), + model_config=model_config, + dtype=dtype, + dpg=dpg, + ) + + # Mark some parameters as tied + # TODO @nouamane: this is only needed for training, can we just mark params as NanotronParameter instead? + mark_tied_parameters(model=model, dpg=dpg, parallel_config=parallel_config) + + # Sanity check model + sanity_check(root_module=model) + + # Load checkpoint + log_rank( + f"Loading checkpoint from {checkpoint_path}:", + logger=logger, + level=logging.INFO, + rank=0, + ) + load_weights(model=model, dpg=dpg, root_folder=checkpoint_path) + + model.eval() + tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + dummy_inputs = [ + "This film was probably inspired by Godzilla", + "If the crew behind 'Zombieland' had a", + ] + + lm_without_head = model.transformer if model_config_cls == "FalconConfig" else model.model + outputs = greedy_search( + input_iter=(GenerationInput(text=text) for text in dummy_inputs), + tokenizer=tokenizer, + # TODO @thomasw21: From ModelWithLoss extract the model. + model=lm_without_head, + # TODO @thomasw21: Figure out how to pass p2p. + p2p=lm_without_head.p2p, + dpg=dpg, + generation_config=GenerationConfig(max_new_tokens=40, max_micro_batch_size=8), + tokenizer_config=TokenizerConfig(max_input_length=8), + ) + dist.barrier() + for output in outputs: + input_ids = output.input_ids + generated_ids = output.generation_ids + if isinstance(input_ids, TensorPointer): + assert isinstance(generated_ids, TensorPointer) + continue + assert isinstance(generated_ids, torch.Tensor) + print( + { + "input": tokenizer.decode(input_ids, clean_up_tokenization_spaces=False), + "generation": tokenizer.decode(generated_ids, clean_up_tokenization_spaces=False), + } + ) + dist.barrier() + + +if __name__ == "__main__": + main() diff --git a/scripts/train.py b/scripts/train.py new file mode 100644 index 00000000..afa958a3 --- /dev/null +++ b/scripts/train.py @@ -0,0 +1,145 @@ +""" + +You can run using command: +``` +USE_FAST=1 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=8 scripts/train.py --config-file configs/config.yaml +USE_FAST=1 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=8 scripts/train.py --config-file configs/config_correctness.yaml +``` +""" +import argparse +from typing import Dict, Iterator, Union + +import torch +from nanotron.config import ( + Config, + PretrainDatasetsArgs, + PretrainNemoArgs, + get_args_from_path, +) +from nanotron.core import logging +from nanotron.core.logging import log_rank +from nanotron.core.parallelism.pipeline_parallelism.tensor_pointer import TensorPointer +from nanotron.core.utils import ( + main_rank_first, +) +from nanotron.dataloaders.dataloader import ( + clm_process, + dummy_infinite_data_generator, + get_datasets, + get_train_dataloader, +) +from nanotron.dataloaders.nemo import get_nemo_dataloader, get_nemo_datasets +from nanotron.trainer import DistributedTrainer +from torch.nn.parallel import DistributedDataParallel +from transformers import AutoTokenizer + +logger = logging.get_logger(__name__) + + +def get_dataloader(trainer) -> Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]: + # Prepare dataloader + tokenizer = AutoTokenizer.from_pretrained(trainer.config.model.hf_model_name) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + if isinstance(trainer.model, DistributedDataParallel): + input_pp_rank = trainer.model.module.input_pp_rank + output_pp_rank = trainer.model.module.output_pp_rank + else: + input_pp_rank = trainer.model.input_pp_rank + output_pp_rank = trainer.model.output_pp_rank + + if config.data.dataset is None: + dataloader = dummy_infinite_data_generator( + micro_batch_size=trainer.micro_batch_size, + sequence_length=trainer.sequence_length, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + vocab_size=trainer.model_config.vocab_size, + seed=trainer.config.data.seed, + dpg=trainer.dpg, + )() + elif isinstance(config.data.dataset, PretrainNemoArgs): + log_rank("Using Nemo Dataloader", logger=logger, level=logging.INFO, rank=0) + + train_dataset, valid_dataset, test_datasets = get_nemo_datasets( + config=config.data.dataset, + global_batch_size=trainer.global_batch_size, + sequence_length=config.tokens.sequence_length, + train_steps=config.tokens.train_steps, + limit_val_batches=config.tokens.limit_val_batches, + val_check_interval=config.tokens.val_check_interval, + test_iters=config.tokens.limit_test_batches, + seed=config.data.seed, + dpg=trainer.dpg, + ) + dataloader = get_nemo_dataloader( + dataset=train_dataset, + sequence_length=trainer.sequence_length, + micro_batch_size=trainer.micro_batch_size, + global_batch_size=trainer.global_batch_size, + num_workers=config.data.num_loading_workers, + cfg=config.data.dataset, + consumed_samples=trainer.consumed_train_samples, + dpg=trainer.dpg, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + dataloader_drop_last=True, + ) + elif isinstance(config.data.dataset, PretrainDatasetsArgs): + log_rank("Using `datasets` library", logger=logger, level=logging.INFO, rank=0) + + with main_rank_first(trainer.dpg.world_pg): + # 1st device processes dataset and cache it, then other devices load from cache + # TODO @nouamanetazi: this may timeout before 1st device finishes processing dataset. Can we have a ctxmanager to modify timeout? + # TODO: generalise to include for validation/test splits + raw_dataset = get_datasets( + dataset_mixer=config.data.dataset.hf_dataset_mixer, splits=config.data.dataset.hf_dataset_splits + )["train"] + tokenizer = AutoTokenizer.from_pretrained(trainer.config.model.hf_model_name) + + train_dataset = clm_process( + raw_dataset=raw_dataset, + tokenizer=tokenizer, + text_column_name=config.data.dataset.text_column_name, + dataset_processing_num_proc_per_process=config.data.dataset.dataset_processing_num_proc_per_process, + dataset_overwrite_cache=config.data.dataset.dataset_overwrite_cache, + sequence_length=trainer.sequence_length, + ) + dataloader = get_train_dataloader( + train_dataset=train_dataset, + sequence_length=trainer.sequence_length, + dpg=trainer.dpg, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + micro_batch_size=trainer.micro_batch_size, + consumed_train_samples=trainer.consumed_train_samples, + dataloader_num_workers=config.data.num_loading_workers, + seed_worker=config.data.seed, + dataloader_drop_last=True, + ) + # Check if we have enough samples for train_steps + assert ( + config.tokens.train_steps - trainer.start_iteration_step + ) * trainer.global_batch_size // trainer.dpg.dp_pg.size() < len( + dataloader + ), f"Dataset is too small for steps ({len(dataloader)} < {(config.tokens.train_steps - trainer.start_iteration_step) * trainer.global_batch_size // trainer.dpg.dp_pg.size()}), Try train_steps<={len(dataloader) * trainer.dpg.dp_pg.size() // trainer.global_batch_size + trainer.start_iteration_step}" + + else: # TODO: other datasets + raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {trainer.config.data.dataset}") + + return dataloader + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML config file") + return parser.parse_args() + + +if __name__ == "__main__": + config_file = get_args().config_file + config: Config = get_args_from_path(config_file) + trainer = DistributedTrainer(config=config) + + dataloader = get_dataloader(trainer) + trainer.train(dataloader=dataloader) diff --git a/src/nanotron/core/gradient_accumulator.py b/src/nanotron/core/gradient_accumulator.py index 164bffe5..129fe3b1 100644 --- a/src/nanotron/core/gradient_accumulator.py +++ b/src/nanotron/core/gradient_accumulator.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from collections import OrderedDict from contextlib import contextmanager -from typing import Callable, Dict, Iterator, Optional, Tuple +from typing import Callable, Dict, Iterable, Iterator, Optional, Tuple import torch from torch.distributed import GradBucket @@ -58,7 +58,7 @@ class FP32GradientAccumulator(GradientAccumulator): def __init__( self, named_parameters: Iterator[Tuple[str, NanotronParameter]], - grad_buckets_named_params: Optional[Iterator[Tuple[str, NanotronParameter]]] = None, + grad_buckets_named_params: Optional[Iterable[Tuple[str, NanotronParameter]]] = None, ): """Create a gradient accumulator that will accumulate gradients in fp32. diff --git a/src/nanotron/core/optimizer/optimizer_from_gradient_accumulator.py b/src/nanotron/core/optimizer/optimizer_from_gradient_accumulator.py index 6b3f5d9b..008433f1 100644 --- a/src/nanotron/core/optimizer/optimizer_from_gradient_accumulator.py +++ b/src/nanotron/core/optimizer/optimizer_from_gradient_accumulator.py @@ -1,5 +1,5 @@ from functools import cache -from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Set, Tuple, Union import torch @@ -12,8 +12,8 @@ class OptimizerFromGradientAccumulator(InheritFromOtherOptimizer): def __init__( self, - gradient_accumulator_builder: Callable[[Iterable[Tuple[str, NanotronParameter]]], GradientAccumulator], - named_params_or_groups: Iterable[Union[Tuple[str, torch.Tensor], Dict[str, Any]]], + gradient_accumulator_builder: Callable[[Iterator[Tuple[str, NanotronParameter]]], GradientAccumulator], + named_params_or_groups: Iterator[Union[Tuple[str, torch.Tensor], Dict[str, Any]]], optimizer_builder: Callable[[Iterable[Dict[str, Any]]], BaseOptimizer], ): named_param_groups = list(named_params_or_groups) diff --git a/src/nanotron/core/parallelism/data_parallelism/utils.py b/src/nanotron/core/parallelism/data_parallelism/utils.py index 6d1c0b2a..6318aed4 100644 --- a/src/nanotron/core/parallelism/data_parallelism/utils.py +++ b/src/nanotron/core/parallelism/data_parallelism/utils.py @@ -2,10 +2,9 @@ from typing import Optional import torch -from torch import nn - from nanotron.core import distributed as dist from nanotron.core.gradient_accumulator import GradientAccumulator +from torch import nn @contextmanager diff --git a/src/nanotron/core/parallelism/parameters.py b/src/nanotron/core/parallelism/parameters.py index 2c9a9993..4549004f 100644 --- a/src/nanotron/core/parallelism/parameters.py +++ b/src/nanotron/core/parallelism/parameters.py @@ -135,11 +135,15 @@ def mark_as_tied( ) def get_tied_info(self) -> TiedInfo: - return getattr(self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME)[self.NANOTRON_PARAMETER_METADATA_TIED_KEY] + return getattr(self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME)[ + self.NANOTRON_PARAMETER_METADATA_TIED_KEY + ] @property def is_tied(self) -> bool: - return self.NANOTRON_PARAMETER_METADATA_TIED_KEY in getattr(self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME) + return self.NANOTRON_PARAMETER_METADATA_TIED_KEY in getattr( + self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME + ) def mark_as_sharded( self, @@ -157,11 +161,15 @@ def mark_as_sharded( ) def get_sharded_info(self) -> ShardedInfo: - return getattr(self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME)[self.NANOTRON_PARAMETER_METADATA_SHARDED_KEY] + return getattr(self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME)[ + self.NANOTRON_PARAMETER_METADATA_SHARDED_KEY + ] @property def is_sharded(self) -> bool: - return self.NANOTRON_PARAMETER_METADATA_SHARDED_KEY in getattr(self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME) + return self.NANOTRON_PARAMETER_METADATA_SHARDED_KEY in getattr( + self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME + ) def sanity_check(root_module: nn.Module): diff --git a/src/nanotron/core/parallelism/pipeline_parallelism/block.py b/src/nanotron/core/parallelism/pipeline_parallelism/block.py index 2f8d399a..5c8aecab 100644 --- a/src/nanotron/core/parallelism/pipeline_parallelism/block.py +++ b/src/nanotron/core/parallelism/pipeline_parallelism/block.py @@ -1,8 +1,6 @@ from typing import Any, Callable, Dict, Optional, Set, Tuple, Union import torch -from torch import nn - from nanotron.core import distributed as dist from nanotron.core.parallelism.pipeline_parallelism.functional import ( recv_from_pipeline_state_buffer, @@ -11,6 +9,7 @@ from nanotron.core.parallelism.pipeline_parallelism.p2p import P2P, BatchTensorSendRecvState from nanotron.core.parallelism.pipeline_parallelism.state import PipelineBatchState, PipelineTrainBatchState from nanotron.core.parallelism.pipeline_parallelism.tensor_pointer import TensorPointer +from torch import nn class PipelineBlock(nn.Module): diff --git a/src/nanotron/core/parallelism/pipeline_parallelism/context_manager.py b/src/nanotron/core/parallelism/pipeline_parallelism/context_manager.py index ba6526f4..ce638180 100644 --- a/src/nanotron/core/parallelism/pipeline_parallelism/context_manager.py +++ b/src/nanotron/core/parallelism/pipeline_parallelism/context_manager.py @@ -1,9 +1,8 @@ from contextlib import contextmanager -from torch import nn as torch_nn - from nanotron.core.parallelism.pipeline_parallelism.block import PipelineBlock from nanotron.core.parallelism.pipeline_parallelism.state import PipelineBatchState +from torch import nn as torch_nn @contextmanager diff --git a/src/nanotron/core/parallelism/pipeline_parallelism/engine.py b/src/nanotron/core/parallelism/pipeline_parallelism/engine.py index 4adf43fe..35948ab9 100644 --- a/src/nanotron/core/parallelism/pipeline_parallelism/engine.py +++ b/src/nanotron/core/parallelism/pipeline_parallelism/engine.py @@ -2,9 +2,6 @@ from typing import Dict, Iterable, Optional, Union import torch -from torch import nn as torch_nn -from torch.nn.parallel import DistributedDataParallel - from nanotron.core import distributed as dist from nanotron.core import logging from nanotron.core.distributed import ProcessGroup @@ -15,6 +12,8 @@ from nanotron.core.parallelism.pipeline_parallelism.state import PipelineTrainBatchState from nanotron.core.parallelism.pipeline_parallelism.tensor_pointer import TensorPointer from nanotron.core.utils import ContextManagers +from torch import nn as torch_nn +from torch.nn.parallel import DistributedDataParallel logger = logging.get_logger(__name__) diff --git a/src/nanotron/core/parallelism/pipeline_parallelism/functional.py b/src/nanotron/core/parallelism/pipeline_parallelism/functional.py index 9750c8c2..8192d01e 100644 --- a/src/nanotron/core/parallelism/pipeline_parallelism/functional.py +++ b/src/nanotron/core/parallelism/pipeline_parallelism/functional.py @@ -1,5 +1,4 @@ import torch - from nanotron.core import logging from nanotron.core.parallelism.pipeline_parallelism.p2p import P2P from nanotron.core.parallelism.pipeline_parallelism.state import PipelineBatchState diff --git a/src/nanotron/core/parallelism/pipeline_parallelism/p2p.py b/src/nanotron/core/parallelism/pipeline_parallelism/p2p.py index 85ea7d33..65e475f7 100644 --- a/src/nanotron/core/parallelism/pipeline_parallelism/p2p.py +++ b/src/nanotron/core/parallelism/pipeline_parallelism/p2p.py @@ -2,7 +2,6 @@ from typing import List, Sequence, Tuple import torch - from nanotron.core import distributed as dist from nanotron.core import logging from nanotron.core.tensor_init import tensor_from_untyped_storage diff --git a/src/nanotron/core/parallelism/pipeline_parallelism/state.py b/src/nanotron/core/parallelism/pipeline_parallelism/state.py index 30e141a4..fd94a823 100644 --- a/src/nanotron/core/parallelism/pipeline_parallelism/state.py +++ b/src/nanotron/core/parallelism/pipeline_parallelism/state.py @@ -4,7 +4,6 @@ from typing import List import torch - from nanotron.core import distributed as dist from nanotron.core import logging from nanotron.core.logging import log_rank diff --git a/src/nanotron/core/parallelism/pipeline_parallelism/utils.py b/src/nanotron/core/parallelism/pipeline_parallelism/utils.py index 0a3f109b..1e237077 100644 --- a/src/nanotron/core/parallelism/pipeline_parallelism/utils.py +++ b/src/nanotron/core/parallelism/pipeline_parallelism/utils.py @@ -1,6 +1,5 @@ -from torch import nn - from nanotron.core.parallelism.pipeline_parallelism.block import PipelineBlock +from torch import nn def get_pp_rank_of(target: str, module: nn.Module): diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index a1a6c226..3d0dbb94 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -44,6 +44,7 @@ get_current_random_state, get_synced_random_state, ) +from nanotron.models import NanotronModel logger = logging.get_logger(__name__) @@ -184,24 +185,8 @@ def lr_lambda(current_step: int): def init_optimizer_and_grad_accumulator( model: nn.Module, optimizer_args: OptimizerArgs, dpg: DistributedProcessGroups ) -> Tuple[BaseOptimizer, GradientAccumulator]: - # Normalize DDP - normalized_model = model.module if isinstance(model, DistributedDataParallel) else model - - module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in normalized_model.named_modules()} - # Fix the root_model - root_model_id = id(normalized_model) - module_id_to_prefix[root_model_id] = "" - - # named parameters - named_parameters = [ - ( - param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) - if param.is_tied - else name, - param, - ) - for name, param in normalized_model.named_parameters() - ] + unwrapped_model: NanotronModel = model.module if isinstance(model, DistributedDataParallel) else model + named_parameters = unwrapped_model.get_named_params_with_tied() # Basic optimizer builder def basic_optimizer_builder(named_param_groups): @@ -283,14 +268,7 @@ def grad_optimizer_builder(named_param_groups): state=FP32GradBucketManager( dp_pg=dpg.dp_pg, accumulator=grad_accumulator, - param_id_to_name={ - id(param): param.get_tied_info().get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - if param.is_tied - else name - for name, param in normalized_model.named_parameters() - }, + param_id_to_name={id(param): name for name, param in named_parameters}, ), hook=get_fp32_accum_hook( reduce_scatter=optimizer.inherit_from(ZeroDistributedOptimizer), reduce_op=dist.ReduceOp.AVG diff --git a/src/nanotron/models/base_model.py b/src/nanotron/models/base_model.py index 67739e37..d2d095a7 100644 --- a/src/nanotron/models/base_model.py +++ b/src/nanotron/models/base_model.py @@ -1,5 +1,5 @@ from abc import ABCMeta, abstractmethod -from typing import Optional +from typing import Iterable, Iterator, Optional, Tuple from torch import nn from transformers import AutoConfig @@ -7,6 +7,7 @@ from nanotron.core import logging from nanotron.core.distributed import ProcessGroup from nanotron.core.logging import log_rank +from nanotron.core.parallelism.parameters import NanotronParameter from nanotron.core.parallelism.pipeline_parallelism.block import PipelineBlock from nanotron.core.process_groups_initializer import DistributedProcessGroups @@ -27,6 +28,10 @@ def __init__(self, *args, **kwargs) -> None: self.input_pp_rank: int self.output_pp_rank: int + # Useful mapping to get param names + self.module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in self.named_modules()} + self.module_id_to_prefix[id(self)] = "" + @abstractmethod def init_model_randomly(self, init_method, scaled_init_method): ... @@ -44,3 +49,22 @@ def log_modules(self, level: int = logging.DEBUG, group: Optional[ProcessGroup] group=group, rank=rank, ) + + def named_parameters( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, NanotronParameter]]: + return super().named_parameters(prefix, recurse, remove_duplicate) + + def get_named_params_with_tied(self) -> Iterable[Tuple[str, NanotronParameter]]: + named_parameters = [ + ( + param.get_tied_info().get_full_name_from_module_id_to_prefix( + module_id_to_prefix=self.module_id_to_prefix + ) + if param.is_tied + else name, + param, + ) + for name, param in self.named_parameters() + ] + return named_parameters diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 51521839..aff468f3 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -6,7 +6,7 @@ import sys import time from pprint import pformat -from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union +from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union import numpy as np import torch @@ -126,17 +126,7 @@ def __init__(self, config: Config): ) # Do a first NCCL sync to warmup and try to avoid Timeout after model/data loading - test_tensor = torch.tensor([dist.get_rank(self.dpg.world_pg)], device=torch.device("cuda")) - test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(self.dpg.world_pg.size())] - dist.all_gather(test_tensor_list, test_tensor, group=self.dpg.world_pg, async_op=False) - dist.barrier() - log_rank( - f"Test NCCL sync for ranks {[t.item() for t in test_tensor_list]}", - logger=logger, - level=logging.INFO, - group=self.dpg.dp_pg, - rank=0, - ) + self.run_nccl_test() # Set random states set_random_seed(self.config.model.seed) @@ -238,12 +228,7 @@ def train(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer # TODO @nouamanetazi: refactor this # Useful mapping - self.normalized_model = self.model.module if isinstance(self.model, DistributedDataParallel) else self.model - self.module_id_to_prefix = { - id(module): f"{module_name}." for module_name, module in self.normalized_model.named_modules() - } - # Fix the root_model - self.module_id_to_prefix[id(self.normalized_model)] = "" + self.unwrapped_model = self.model.module if isinstance(self.model, DistributedDataParallel) else self.model prof = get_profiler(config=self.config) with self.tb_context as tb_writer: @@ -318,7 +303,7 @@ def training_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Tenso # Sync tied weights # TODO @nouamane: Put this in hooks so we can overlap communication with gradient computation on the last backward pass. sync_tied_weights_gradients( - module=self.normalized_model, + module=self.unwrapped_model, dpg=self.dpg, grad_accumulator=self.grad_accumulator, ) @@ -337,15 +322,8 @@ def training_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Tenso if self.config.optimizer.clip_grad is not None: # Normalize DDP named_parameters = [ - ( - param.get_tied_info().get_full_name_from_module_id_to_prefix( - module_id_to_prefix=self.module_id_to_prefix - ) - if param.is_tied - else name, - param, - ) - for name, param in self.normalized_model.named_parameters() + (name, param) + for name, param in self.unwrapped_model.get_named_params_with_tied() if param.requires_grad ] # TODO @nouamane: we need to split `world_rank_matrix` along PP axis, to separate ref from active model @@ -477,8 +455,6 @@ def build_model( pipeline_blocks = [module for name, module in model.named_modules() if isinstance(module, PipelineBlock)] # "cuda" is already defaulted for each process to it's own cuda device with init_on_device_and_dtype(device=device, dtype=dtype): - # TODO: https://github.com/huggingface/nanotron/issues/65 - # Balance compute across PP blocks block_compute_costs = model.get_block_compute_costs() block_cumulative_costs = np.cumsum( @@ -574,7 +550,7 @@ def _init_model( model_config: AutoConfig, model_builder: Callable[[], NanotronModel], target_pp_ranks: Optional[List[int]] = None, - ) -> Tuple[NanotronModel]: + ) -> NanotronModel: config = self.config dpg = self.dpg @@ -731,8 +707,8 @@ def before_tbi_sanity_checks(self) -> None: # SANITY CHECK: Tied weights are synchronized tied_params_list = sorted( get_tied_id_to_param( - parameters=self.normalized_model.parameters(), - root_module=self.normalized_model, + parameters=self.unwrapped_model.parameters(), + root_module=self.unwrapped_model, ).items(), key=lambda x: x[0], ) @@ -761,14 +737,14 @@ def after_tbi_sanity_checks(self) -> None: # SANITY CHECK: Check that gradient flow on the entire model # SANITY CHECK: Check that all parameters that required gradients, have actually a gradient # SANITY CHECK: Check for nan/inf - for name, param in self.normalized_model.named_parameters(): + for name, param in self.unwrapped_model.named_parameters(): if not param.requires_grad: continue if param.is_tied: tied_info = param.get_tied_info() name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=self.module_id_to_prefix + module_id_to_prefix=self.unwrapped_model.module_id_to_prefix ) if self.grad_accumulator is not None: @@ -790,7 +766,7 @@ def before_optim_step_sanity_checks(self) -> None: # SANITY CHECK: Test tied weights gradients are synchronized for (name, group_ranks), param in sorted( get_tied_id_to_param( - parameters=self.normalized_model.parameters(), root_module=self.normalized_model + parameters=self.unwrapped_model.parameters(), root_module=self.unwrapped_model ).items(), key=lambda x: x[0], ): @@ -811,14 +787,14 @@ def before_optim_step_sanity_checks(self) -> None: ) # SANITY CHECK: Test gradients are synchronized across DP - for name, param in sorted(self.normalized_model.named_parameters(), key=lambda x: x[0]): + for name, param in sorted(self.unwrapped_model.named_parameters(), key=lambda x: x[0]): if not param.requires_grad: continue if param.is_tied: tied_info = param.get_tied_info() name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=self.module_id_to_prefix + module_id_to_prefix=self.unwrapped_model.module_id_to_prefix ) if self.grad_accumulator is not None: @@ -842,7 +818,7 @@ def before_optim_step_sanity_checks(self) -> None: # SANITY CHECK: Tied weights are synchronized tied_params_list = sorted( get_tied_id_to_param( - parameters=self.normalized_model.parameters(), root_module=self.normalized_model + parameters=self.unwrapped_model.parameters(), root_module=self.unwrapped_model ).items(), key=lambda x: x[0], ) @@ -869,7 +845,22 @@ def after_optim_step_sanity_checks(self) -> None: level=logging.ERROR, ) + def run_nccl_test(self) -> None: + """NCCL sanity check""" + test_tensor = torch.tensor([dist.get_rank(self.dpg.world_pg)], device=torch.device("cuda")) + test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(self.dpg.world_pg.size())] + dist.all_gather(test_tensor_list, test_tensor, group=self.dpg.world_pg, async_op=False) + dist.barrier() + log_rank( + f"Test NCCL sync for ranks {[t.item() for t in test_tensor_list]}", + logger=logger, + level=logging.INFO, + group=self.dpg.dp_pg, + rank=0, + ) + +# TODO @nouamane: move to NanotronModel like tflops because it depends on the model def mark_tied_parameters( model: NanotronModel, dpg: DistributedProcessGroups, parallel_config: Optional[ParallelismArgs] = None ): diff --git a/tests/helpers/distributed_tensor.py b/tests/helpers/distributed_tensor.py index b7594983..5928d6c0 100644 --- a/tests/helpers/distributed_tensor.py +++ b/tests/helpers/distributed_tensor.py @@ -1,5 +1,4 @@ import torch - from nanotron.core import distributed as dist from nanotron.core.distributed import ProcessGroup, get_global_rank diff --git a/tests/helpers/dummy.py b/tests/helpers/dummy.py index 17ea29a3..d87c65fb 100644 --- a/tests/helpers/dummy.py +++ b/tests/helpers/dummy.py @@ -2,9 +2,6 @@ from typing import Union import torch -from torch import nn -from torch.nn.parallel import DistributedDataParallel - from nanotron.core import distributed as dist from nanotron.core.dataclass import DistributedProcessGroups from nanotron.core.optimizer.base import BaseOptimizer @@ -16,6 +13,8 @@ from nanotron.core.parallelism.pipeline_parallelism.tensor_pointer import TensorPointer from nanotron.core.parallelism.tied_parameters import tie_parameters from nanotron.core.utils import init_on_device_and_dtype +from torch import nn +from torch.nn.parallel import DistributedDataParallel class DummyModel(nn.Module): diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 52493bc1..1657d298 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -4,9 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple import torch.cuda -from torch.distributed.launcher import elastic_launch - from nanotron.core.process_groups_initializer import get_process_groups +from torch.distributed.launcher import elastic_launch def available_gpus(): diff --git a/tests/test_checkpointing.py b/tests/test_checkpointing.py index 72c28232..bec212a1 100644 --- a/tests/test_checkpointing.py +++ b/tests/test_checkpointing.py @@ -1,10 +1,9 @@ from typing import Union import torch -from torch import nn - from nanotron.core.parallelism.pipeline_parallelism.tensor_pointer import TensorPointer from nanotron.core.utils import checkpoint_method +from torch import nn class CheckpointedModel(nn.Module): diff --git a/tests/test_clip_grads.py b/tests/test_clip_grads.py index d5dd80ce..d0f1ecc9 100644 --- a/tests/test_clip_grads.py +++ b/tests/test_clip_grads.py @@ -5,8 +5,6 @@ import torch from helpers.dummy import DummyModel, dummy_infinite_data_loader from helpers.utils import available_gpus, init_distributed -from torch import nn - from nanotron.clip_grads import clip_grad_norm from nanotron.core import distributed as dist from nanotron.core.gradient_accumulator import ( @@ -28,6 +26,7 @@ ) from nanotron.core.process_groups_initializer import DistributedProcessGroups from nanotron.core.utils import assert_tensor_synced_across_pg, init_on_device_and_dtype +from torch import nn @pytest.mark.skipif(available_gpus() < 2, reason="test_clip_grads_with_pp requires at least 2 gpus") diff --git a/tests/test_data_parallel.py b/tests/test_data_parallel.py index cbff8d22..f11a2d53 100644 --- a/tests/test_data_parallel.py +++ b/tests/test_data_parallel.py @@ -4,14 +4,13 @@ import torch from helpers.exception import assert_fail_except_rank_with from helpers.utils import available_gpus, init_distributed -from torch import nn -from torch.distributed import GradBucket - from nanotron.core import distributed as dist from nanotron.core.parallelism.data_parallelism.utils import ddp_trigger_sync_in_bwd from nanotron.core.parallelism.parameters import NanotronParameter from nanotron.core.process_groups_initializer import DistributedProcessGroups from nanotron.core.utils import assert_tensor_synced_across_pg +from torch import nn +from torch.distributed import GradBucket @pytest.mark.skipif(available_gpus() < 2, reason="Testing test_ddp_with_afab requires at least 2 gpus") diff --git a/tests/test_p2p.py b/tests/test_p2p.py index 4263e296..9a6d34e2 100644 --- a/tests/test_p2p.py +++ b/tests/test_p2p.py @@ -4,7 +4,6 @@ import torch from helpers.exception import assert_fail_with from helpers.utils import available_gpus, init_distributed - from nanotron.core import distributed as dist from nanotron.core.dataclass import DistributedProcessGroups from nanotron.core.parallelism.pipeline_parallelism.p2p import P2P diff --git a/tests/test_parameter.py b/tests/test_parameter.py index 97aefe8c..0c6fcac0 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -1,9 +1,8 @@ import torch from helpers.exception import assert_fail_with -from torch import nn - from nanotron.core.parallelism.parameters import NanotronParameter from nanotron.core.utils import DTypeInvariantTensor, init_on_device_and_dtype +from torch import nn def test_nanotron_parameter_does_not_override_some_parameter_variable(): diff --git a/tests/test_parameters_accumulate_gradient_in_fp32.py b/tests/test_parameters_accumulate_gradient_in_fp32.py index 1e58fd78..3773ac0e 100644 --- a/tests/test_parameters_accumulate_gradient_in_fp32.py +++ b/tests/test_parameters_accumulate_gradient_in_fp32.py @@ -1,13 +1,11 @@ import copy +import nanotron.core.distributed as dist import pytest import torch from helpers.dummy import DummyModel, dummy_infinite_data_loader from helpers.exception import assert_fail_except_rank_with, timeout_after from helpers.utils import available_gpus, init_distributed -from torch import nn - -import nanotron.core.distributed as dist from nanotron.core.dataclass import DistributedProcessGroups from nanotron.core.gradient_accumulator import FP32GradBucketManager, FP32GradientAccumulator, get_fp32_accum_hook from nanotron.core.optimizer import ZeroDistributedOptimizer @@ -30,6 +28,7 @@ tie_parameters, ) from nanotron.core.utils import ContextManagers, assert_tensor_synced_across_pg, init_on_device_and_dtype +from torch import nn @pytest.mark.parametrize("half_precision", [torch.float16, torch.bfloat16]) diff --git a/tests/test_pipeline_parallel.py b/tests/test_pipeline_parallel.py index a183060f..8392d59f 100644 --- a/tests/test_pipeline_parallel.py +++ b/tests/test_pipeline_parallel.py @@ -4,9 +4,6 @@ import torch from helpers.dummy import DummyModel, dummy_infinite_data_loader from helpers.utils import available_gpus, init_distributed -from torch import nn -from torch.nn import functional as F - from nanotron.core import distributed as dist from nanotron.core.parallelism.pipeline_parallelism.block import PipelineBlock from nanotron.core.parallelism.pipeline_parallelism.engine import ( @@ -18,6 +15,8 @@ from nanotron.core.parallelism.pipeline_parallelism.tensor_pointer import TensorPointer from nanotron.core.process_groups_initializer import DistributedProcessGroups from nanotron.core.utils import init_on_device_and_dtype +from torch import nn +from torch.nn import functional as F @pytest.mark.skipif(available_gpus() < 2, reason="Testing build_and_set_rank requires at least 2 gpus") diff --git a/tests/test_random_state.py b/tests/test_random_state.py index fcdd4477..8da311a0 100644 --- a/tests/test_random_state.py +++ b/tests/test_random_state.py @@ -1,7 +1,6 @@ import pytest import torch from helpers.utils import available_gpus, init_distributed - from nanotron.core import distributed as dist from nanotron.core.dataclass import DistributedProcessGroups, RandomStates from nanotron.core.random import ( diff --git a/tests/test_serialize.py b/tests/test_serialize.py index 1273f97b..e823d1b4 100644 --- a/tests/test_serialize.py +++ b/tests/test_serialize.py @@ -8,8 +8,6 @@ init_distributed, is_dict_equal, ) -from torch.nn.parallel import DistributedDataParallel - from nanotron.core import distributed as dist from nanotron.core.dataclass import DistributedProcessGroups, RandomStates from nanotron.core.gradient_accumulator import FP32GradientAccumulator @@ -34,6 +32,7 @@ ) from nanotron.core.serialize.constants import CHECKPOINT_VERSION from nanotron.core.serialize.meta import TensorMetadataV2 +from torch.nn.parallel import DistributedDataParallel def test_save_and_load_with_changed_topolgy(): @@ -362,11 +361,11 @@ def _test_save_optimizer_with_additional_state_dict_keys(dpg: DistributedProcess if isinstance(model, DistributedDataParallel): # Remove the annoying "module." prefix - normalized_model = model.module + unwrapped_model = model.module else: - normalized_model = model + unwrapped_model = model - named_parameters = list(normalized_model.named_parameters()) + named_parameters = list(unwrapped_model.named_parameters()) optimizer = OptimizerFromGradientAccumulator( gradient_accumulator_builder=lambda named_params: FP32GradientAccumulator(named_parameters=named_params), @@ -390,7 +389,7 @@ def _test_save_optimizer_with_additional_state_dict_keys(dpg: DistributedProcess model=model, pg=dpg.pp_pg, batch=[minibatch], grad_accumulator=grad_accumulator ) # Manually sync tied parameters - sync_tied_weights_gradients(module=normalized_model, dpg=dpg, grad_accumulator=grad_accumulator) + sync_tied_weights_gradients(module=unwrapped_model, dpg=dpg, grad_accumulator=grad_accumulator) # Optimizer steps optimizer.step() optimizer.zero_grad() diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index 881abb00..c97232e7 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -5,8 +5,6 @@ import pytest import torch from helpers.utils import available_gpus, init_distributed -from torch import nn as torch_nn - from nanotron.core import distributed as dist from nanotron.core.distributed import get_global_rank from nanotron.core.parallelism.tensor_parallelism.enum import TensorParallelLinearMode @@ -16,6 +14,7 @@ TensorParallelRowLinear, ) from nanotron.core.process_groups_initializer import DistributedProcessGroups +from torch import nn as torch_nn @pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, available_gpus() + 1)]) diff --git a/tests/test_tie_weights.py b/tests/test_tie_weights.py index 56c64371..91631ab1 100644 --- a/tests/test_tie_weights.py +++ b/tests/test_tie_weights.py @@ -2,8 +2,6 @@ from helpers.distributed_tensor import assert_tensor_equal_over_group from helpers.exception import assert_fail_with from helpers.utils import init_distributed -from torch import nn - from nanotron.core import distributed as dist from nanotron.core.dataclass import DistributedProcessGroups from nanotron.core.parallelism.parameters import NanotronParameter @@ -12,6 +10,7 @@ sync_tied_weights_gradients, tie_parameters, ) +from torch import nn def test_tie_weight_in_same_device(): diff --git a/tests/test_training.py b/tests/test_training.py new file mode 100644 index 00000000..a54a8de8 --- /dev/null +++ b/tests/test_training.py @@ -0,0 +1,72 @@ +"""Script to test correctness of training script by comparing loss value after 100th iteration with expected loss value + +```bash +python tests/test_training.py +``` +""" + +import atexit +import os +import re +import signal +import subprocess +import time + +EXPECTED_LOSS = 8e-03 +CONFIG_FILE = "configs/config_correctness.yaml" +TRAIN_SCRIPT = "scripts/train.py" +NUM_GPUS = 8 +CHECK_ITERATION = 100 + + +def exit_with_children(): + """Kill all children processes when this process exits""" + os.killpg(0, signal.SIGKILL) + + +def extract_loss(line): + """Extract loss value from the line""" + # extract loss value of the type | lm_loss: 7.087376E-03 | OR | lm_loss: 7.087376E+03 | + try: + return float(re.search(r"lm_loss: (\d+.\d+E[-+]?\d+)", line.decode("utf-8")).group(1)) + except AttributeError: + raise ValueError(f"Could not extract loss value from line: {line}") + + +if __name__ == "__main__": + cmd = f"USE_FAST=1 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={NUM_GPUS} {TRAIN_SCRIPT} --config-file {CONFIG_FILE}" + start_time = time.time() + + os.setpgrp() # create new process group, become its leader + atexit.register(exit_with_children) # kill all children processes when this process exits + + # read logs in streaming fashion + for line in subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT).stdout: + print(line.decode("utf-8"), end="") + + # for all iterations >= 30, loss should be below 0.01 + if re.search(r"iteration: (\d+) / ", line.decode("utf-8")): + if int(re.search(r"iteration: (\d+) / ", line.decode("utf-8")).group(1)) >= 30: + loss = extract_loss(line) + if loss > 2e-02: + print("=" * 10, "TEST FAILED", "=" * 10) + print(f"Loss after 30th iteration is {loss} which is bigger than expected loss 0.01") + print(f"Time taken: {time.time() - start_time}") + exit(1) + + if re.search(rf"iteration: {CHECK_ITERATION} / ", line.decode("utf-8")): + loss = extract_loss(line) + if loss > EXPECTED_LOSS: + print("=" * 10, "TEST FAILED", "=" * 10) + print( + f"Loss after {CHECK_ITERATION}th iteration is {loss} which is bigger than expected loss {EXPECTED_LOSS}" + ) + print(f"Time taken: {time.time() - start_time:.2f}s") + exit(1) + else: + print("=" * 10, "TEST PASSED", "=" * 10) + print( + f"Loss after {CHECK_ITERATION}th iteration is {loss} which is smaller than expected loss {EXPECTED_LOSS}" + ) + print(f"Time taken: {time.time() - start_time:.2f}s") + exit(0) diff --git a/tests/test_zero.py b/tests/test_zero.py index b5a4acce..7cb0c6ba 100644 --- a/tests/test_zero.py +++ b/tests/test_zero.py @@ -6,9 +6,6 @@ from helpers.dummy import dummy_infinite_data_loader, init_dummy_model from helpers.exception import assert_fail_with from helpers.utils import available_gpus, init_distributed -from torch import nn as torch_nn -from torch.nn.parallel import DistributedDataParallel - from nanotron.core import distributed as dist from nanotron.core.dataclass import DistributedProcessGroups, RandomStates from nanotron.core.optimizer import NamedOptimizer, ZeroDistributedOptimizer @@ -21,6 +18,8 @@ from nanotron.core.parallelism.tensor_parallelism.enum import TensorParallelLinearMode from nanotron.core.parallelism.tied_parameters import sync_tied_weights_gradients from nanotron.core.random import branch_random_state, get_current_random_state, get_synced_random_state +from torch import nn as torch_nn +from torch.nn.parallel import DistributedDataParallel @pytest.mark.parametrize("tp,dp,pp", [pytest.param(1, i, 1) for i in range(1, available_gpus() + 1)])