
Commit

fix nan in fp8 mlp
xrsrke committed May 28, 2024
1 parent 85cd79b commit 59295ad
Showing 13 changed files with 458 additions and 45 deletions.
117 changes: 117 additions & 0 deletions examples/config_fp8_llama.yaml
@@ -0,0 +1,117 @@
checkpoints:
  checkpoint_interval: 10
  checkpoints_path: checkpoints
  checkpoints_path_is_shared_file_system: false
  # resume_checkpoint_path: checkpoints
  save_initial_state: false
data_stages:
- data:
    dataset:
      dataset_overwrite_cache: false
      dataset_processing_num_proc_per_process: 1
      hf_dataset_config_name: null
      hf_dataset_or_datasets: stas/openwebtext-10k
      hf_dataset_splits: train
      text_column_name: text
    num_loading_workers: 1
    seed: 42
  name: Stable Training Stage
  start_training_step: 1
- data:
    dataset:
      dataset_overwrite_cache: false
      dataset_processing_num_proc_per_process: 1
      hf_dataset_config_name: null
      hf_dataset_or_datasets: stas/openwebtext-10k
      hf_dataset_splits: train
      text_column_name: text
    num_loading_workers: 1
    seed: 42
  name: Annealing Phase
  start_training_step: 10
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: true
  project: debug
  run: tiny_llama_%date_%jobid
  seed: 42
  step: null
lighteval: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
  monitor_model_states: true
model:
  ddp_bucket_cap_mb: 25
  dtype: float8
  init_method:
    # std: 0.25 # sqrt(1/16)
    # std: 0.125 # sqrt(1/64)
    std: 0.04419417382415922 # sqrt(1/512)
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 1
    eos_token_id: 2
    # hidden_act: silu
    hidden_act: gelu
    hidden_size: 512
    initializer_range: 0.02
    intermediate_size: 2048
    is_llama_config: true
    max_position_embeddings: 256
    num_attention_heads: 4
    num_hidden_layers: 1
    num_key_value_heads: 4
    pad_token_id: null
    pretraining_tp: 1
    rms_norm_eps: 1.0e-05
    rope_scaling: null
    tie_word_embeddings: false
    use_cache: true
    vocab_size: 256
optimizer:
  accumulate_grad_in_fp32: false
  # clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.01
    lr_decay_starting_step: null
    lr_decay_steps: 13
    lr_decay_style: linear
    lr_warmup_steps: 2
    lr_warmup_style: constant
    min_decay_lr: 1.0e-05
  optimizer_factory:
    adam_beta1: 0.9
    adam_beta2: 0.999
    adam_eps: 1.0e-08
    name: adam
    torch_adam_is_fused: true
  weight_decay: 0.1
  zero_stage: 0
parallelism:
  dp: 1
  expert_parallel_size: 1
  pp: 1
  pp_engine: 1f1b
  tp: 2
  # tp_linear_async_communication: true
  # tp_mode: REDUCE_SCATTER

  tp_linear_async_communication: false
  tp_mode: ALL_REDUCE

profiler: null
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: robot-test/dummy-tokenizer-wordlevel
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 1
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 2
  sequence_length: 256
  train_steps: 15
  val_check_interval: -1
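
A note on the init_method block above: each commented-out std value is sqrt(1/fan) for a different width, and the active value corresponds to fan = 512, which matches hidden_size in model_config (treating hidden_size as the intended fan-in is an assumption; the config itself only gives the numbers). A quick check in plain Python:

import math

hidden_size = 512  # model_config.hidden_size above
print(math.sqrt(1 / 16))           # 0.25, the first commented-out std
print(math.sqrt(1 / 64))           # 0.125, the second commented-out std
print(math.sqrt(1 / hidden_size))  # 0.04419417382415922, the std actually used
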
8 changes: 4 additions & 4 deletions examples/config_tiny_llama.yaml
@@ -72,20 +72,20 @@ optimizer:
  accumulate_grad_in_fp32: false
  clip_grad: 1.0
  learning_rate_scheduler:
-    learning_rate: 0.0003
+    learning_rate: 0.001
    lr_decay_starting_step: null
    lr_decay_steps: 13
-    lr_decay_style: cosine
+    lr_decay_style: linear
    lr_warmup_steps: 2
-    lr_warmup_style: linear
+    lr_warmup_style: constant
    min_decay_lr: 1.0e-05
  optimizer_factory:
    adam_beta1: 0.9
    adam_beta2: 0.95
    adam_eps: 1.0e-08
    name: adamW
    torch_adam_is_fused: true
-  weight_decay: 0.01
+  weight_decay: 0.1
  zero_stage: 0
parallelism:
  dp: 1
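
The scheduler edits above replace linear warmup + cosine decay with constant warmup + linear decay. A rough standalone sketch of the resulting schedule, using the values from this file (this is not nanotron's scheduler implementation, and the exact step-boundary handling is assumed):

def lr_at(step, peak_lr=0.001, warmup_steps=2, decay_steps=13, min_lr=1.0e-05):
    # lr_warmup_style: constant -> hold peak_lr during warmup
    if step <= warmup_steps:
        return peak_lr
    # lr_decay_style: linear -> interpolate from peak_lr down to min_decay_lr
    frac = min((step - warmup_steps) / decay_steps, 1.0)
    return peak_lr - (peak_lr - min_lr) * frac

print([round(lr_at(s), 6) for s in range(1, 16)])
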
15 changes: 12 additions & 3 deletions src/nanotron/config/config.py
@@ -42,10 +42,10 @@ class BenchArgs:
class LoggingArgs:
    """
    Arguments related to logging
    monitor_model_states: whether to monitor the model states including the statistics
    of activations, input gradients, output gradients and model weights during training.
    # NOTE:
    - You could use the `iteration_step_info_interval` to control the frequency of logging the model states.
    """
@@ -301,11 +301,20 @@ class AdamWOptimizerArgs:
    name: str = "adamW"


@dataclass
class AdamOptimizerArgs:
    adam_eps: float
    adam_beta1: float
    adam_beta2: float
    torch_adam_is_fused: bool
    name: str = "adam"


@dataclass
class OptimizerArgs:
    """Arguments related to the optimizer and learning rate"""

-    optimizer_factory: Union[SGDOptimizerArgs, AdamWOptimizerArgs]
+    optimizer_factory: Union[SGDOptimizerArgs, AdamOptimizerArgs, AdamWOptimizerArgs]
    zero_stage: int
    weight_decay: float
    clip_grad: Optional[float]
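
With both Adam variants in the Union, the optimizer_factory block of a config is expected to resolve to whichever dataclass matches its name field. A standalone sketch of that mapping (hand-rolled dict handling, not nanotron's actual config loader):

from dataclasses import dataclass


@dataclass
class AdamOptimizerArgs:
    adam_eps: float
    adam_beta1: float
    adam_beta2: float
    torch_adam_is_fused: bool
    name: str = "adam"


# optimizer_factory block as it appears in examples/config_fp8_llama.yaml
factory_cfg = {
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
    "adam_eps": 1.0e-08,
    "name": "adam",
    "torch_adam_is_fused": True,
}
assert factory_cfg.pop("name") == "adam"
print(AdamOptimizerArgs(**factory_cfg))
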
2 changes: 2 additions & 0 deletions src/nanotron/constants.py
@@ -14,3 +14,5 @@

# NOTE: hacky, remove after working
IS_FP8: bool = True

NN_STATES = None
5 changes: 4 additions & 1 deletion src/nanotron/fp8/constants.py
@@ -40,6 +40,8 @@
# wgrad.window_size = 1, ograd.window_size = 16 (this one is the input of the backward pass),
# input_grad.window_size = 1 (this one is the output of the backward pass)

# TODO(xrsrke): differentiate the precision that you initializes model weight
# and the accumulation precision in FP8 recipe
FP8LM_RECIPE = FP8TrainingRecipe(
    # linear=FP8LinearRecipe(
    #     accum_dtype=DTypes.KFLOAT16,

@@ -74,7 +76,8 @@
        input_grad=FP8TensorRecipe(dtype=DTypes.FP8E4M3, margin=0, interval=1),
        weight_grad=FP8TensorRecipe(dtype=DTypes.FP8E4M3, margin=0, interval=1),
        output_grad=FP8TensorRecipe(dtype=DTypes.FP8E5M2, margin=0, interval=1),
-        split_accumulator=FP8SplitAccumulator(output=False, input_grad=True, weight_grad=True),
+        # split_accumulator=FP8SplitAccumulator(output=False, input_grad=True, weight_grad=True), # NOTE: msamp use this
+        split_accumulator=FP8SplitAccumulator(output=True, input_grad=True, weight_grad=True),
    ),
    optim=FP8OptimRecipe(
        accum_dtype=DTypes.KFLOAT32,
9 changes: 7 additions & 2 deletions src/nanotron/fp8/linear.py
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Optional, Tuple, Union, cast

-import pydevd
+# import pydevd
import torch
import transformer_engine as te # noqa
from torch import nn
@@ -58,6 +58,8 @@ def __init__(
        )
        assert quant_w.dtype in [torch.uint8, torch.int8], f"got {self.weight.data.dtype}"
        self.weight = quant_w
        # assert self.weight.data.orig_data.abs().max() == quant_w.fp8_meta.amax

        assert self.weight.data.dtype in [torch.uint8, torch.int8], f"got {self.weight.data.dtype}"
        self.metadatas = FP8LinearMeta()
        self.accum_qtype = accum_qtype
@@ -72,6 +74,9 @@ def forward(self, input: Union[FP8Tensor, torch.Tensor]) -> torch.Tensor:

        return F.linear(input, self.weight.data, self.bias, self.accum_qtype, self.metadatas)

    def __repr__(self) -> str:
        return f"FP8{super().__repr__()}"

class _FP8Matmul(torch.autograd.Function):
    @staticmethod
@@ -143,7 +148,7 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
        ∂L/∂W = Xᵀ @ ∂L/∂Y
        Reference: https://web.eecs.umich.edu/~justincj/teaching/eecs442/notes/linear-backprop.html
        """
-        pydevd.settrace(suspend=False, trace_only_current_thread=True)
+        # pydevd.settrace(suspend=False, trace_only_current_thread=True)
        fp8_input, fp8_weight = ctx.saved_tensors
        accum_qtype = ctx.accum_qtype

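
The new __repr__ simply prefixes the parent class repr so FP8 layers stand out when a model is printed. A minimal illustration with a plain nn.Linear subclass (a stand-in for demonstration, not the real FP8Linear, which also quantizes its weights):

import torch.nn as nn


class DemoLinear(nn.Linear):
    # Same pattern as the added FP8Linear.__repr__: reuse the parent repr and tag it
    def __repr__(self) -> str:
        return f"FP8{super().__repr__()}"


print(DemoLinear(4, 8))  # -> FP8DemoLinear(in_features=4, out_features=8, bias=True)
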
4 changes: 3 additions & 1 deletion src/nanotron/fp8/tensor.py
@@ -208,6 +208,8 @@ def __new__(
        assert dtype in [DTypes.FP8E4M3, DTypes.FP8E5M2, DTypes.KFLOAT16]

        fp8_meta = cls._get_metadata(tensor, dtype, interval)

        backup_fp8_meta = deepcopy(fp8_meta)
        # else:
        #     assert tensor.dtype in FP8_DTYPES
        #     # fp8_tensor = tensor

@@ -220,7 +222,7 @@
        # TODO(xrsrke): move update inverse scaling to FP8Meta's initialization
        obj = torch.Tensor._make_subclass(cls, fp8_tensor)
        # TODO(xrsrke): use a different name, because FP16Tensor also has fp8_meta
-        obj.fp8_meta = fp8_meta
+        obj.fp8_meta = backup_fp8_meta
        obj.orig_data = tensor
        return obj

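
The key change in this hunk is that the tensor now keeps a deepcopy of the metadata computed at quantization time rather than a reference to the shared object. An illustrative example of why that distinction matters when the shared object is later mutated in place (Meta here is a hypothetical stand-in, not nanotron's FP8Meta):

from copy import deepcopy
from dataclasses import dataclass


@dataclass
class Meta:  # hypothetical stand-in for FP8Meta
    amax: float
    scale: float


meta = Meta(amax=3.0, scale=448.0 / 3.0)
by_reference = meta
by_copy = deepcopy(meta)

meta.amax = 7.0  # in-place update somewhere else in the code

print(by_reference.amax)  # 7.0 -> silently follows the mutation
print(by_copy.amax)       # 3.0 -> still the value captured at quantization time
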
16 changes: 16 additions & 0 deletions src/nanotron/helpers.py
@@ -338,6 +338,22 @@ def optimizer(param_groups):
                fused=False,
            )

    elif optimizer_args.optimizer_factory.name == "adam":

        def optimizer(param_groups):
            return torch.optim.Adam(
                param_groups,
                lr=optimizer_args.learning_rate_scheduler.learning_rate,
                weight_decay=optimizer_args.weight_decay,
                eps=optimizer_args.optimizer_factory.adam_eps,
                betas=(optimizer_args.optimizer_factory.adam_beta1, optimizer_args.optimizer_factory.adam_beta2),
                # fused=optimizer_args.optimizer_factory.torch_adam_is_fused,
                # NOTE: fused (bool, optional) – whether the fused implementation (CUDA only) is used.
                # Currently, torch.float64, torch.float32, torch.float16, and torch.bfloat16
                # in FP8 training, model parameters are INT8
                fused=False,
            )

    elif optimizer_args.optimizer_factory.name == "sgd":

        def optimizer(param_groups):
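
A standalone sketch of what the new "adam" branch builds (hyperparameter values taken from examples/config_fp8_llama.yaml; the parameter here is an ordinary fp32 tensor only to keep the example runnable). fused=False follows the NOTE in the diff: the fused CUDA kernel currently handles float64/float32/float16/bfloat16 parameters, while FP8 training stores parameters as int8/uint8 tensors.

import torch

params = [torch.nn.Parameter(torch.randn(512, 512))]
optimizer = torch.optim.Adam(
    params,
    lr=0.01,             # optimizer.learning_rate_scheduler.learning_rate
    weight_decay=0.1,    # optimizer.weight_decay
    eps=1.0e-08,         # optimizer_factory.adam_eps
    betas=(0.9, 0.999),  # optimizer_factory.adam_beta1 / adam_beta2
    fused=False,         # fused Adam does not cover int8/uint8 parameters
)

params[0].grad = torch.zeros_like(params[0])  # dummy gradient so step() has work to do
optimizer.step()
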