
Commit

Merge pull request #139 from huggingface/bouteille/fix-weight-decay
Add param group weight decay
3outeille authored Apr 19, 2024
2 parents 96ac464 + 2de09d3 commit 4799d24
Showing 12 changed files with 811 additions and 77 deletions.
11 changes: 7 additions & 4 deletions examples/config_tiny_llama.py
@@ -2,6 +2,7 @@
import os

from nanotron.config import (
AdamWOptimizerArgs,
CheckpointsArgs,
Config,
DataArgs,
@@ -62,11 +63,13 @@
weight_decay=0.01,
clip_grad=1.0,
accumulate_grad_in_fp32=True,
adam_eps=1e-08,
adam_beta1=0.9,
adam_beta2=0.95,
torch_adam_is_fused=True,
learning_rate_scheduler=learning_rate,
optimizer_factory=AdamWOptimizerArgs(
adam_eps=1e-08,
adam_beta1=0.9,
adam_beta2=0.95,
torch_adam_is_fused=True,
),
)

parallelism = ParallelismArgs(
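After this change, the Adam-specific hyper-parameters no longer sit directly on OptimizerArgs; they move into the nested AdamWOptimizerArgs passed via optimizer_factory. Reassembled for readability, the new-style block from this example looks roughly like the sketch below (learning_rate is the LRSchedulerArgs instance defined earlier in the script; zero_stage=0 is taken from the matching YAML example that follows):

optimizer = OptimizerArgs(
    zero_stage=0,
    weight_decay=0.01,
    clip_grad=1.0,
    accumulate_grad_in_fp32=True,
    learning_rate_scheduler=learning_rate,
    optimizer_factory=AdamWOptimizerArgs(
        adam_eps=1e-08,
        adam_beta1=0.9,
        adam_beta2=0.95,
        torch_adam_is_fused=True,
    ),
)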
43 changes: 38 additions & 5 deletions examples/config_tiny_llama.yaml
@@ -1,3 +1,34 @@
checkpoints:
checkpoint_interval: 10
checkpoints_path: /fsx/ferdinandmom/ferdinand-hf/nanotron/checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
save_initial_state: false
data_stages:
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small
hf_dataset_splits: train
text_column_name: completion
num_loading_workers: 1
seed: 42
name: Stable Training Stage
start_training_step: 1
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small
hf_dataset_splits: train
text_column_name: completion
num_loading_workers: 1
seed: 42
name: Annealing Phase
start_training_step: 10
general:
benchmark_csv_path: null
consumed_train_samples: null
@@ -34,9 +65,6 @@ model:
vocab_size: 256
optimizer:
accumulate_grad_in_fp32: true
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.001
@@ -46,8 +74,13 @@ optimizer:
lr_warmup_steps: 2000 # 20% of the total steps
lr_warmup_style: linear
min_decay_lr: 1.0e-05
torch_adam_is_fused: true
weight_decay: 0.1
optimizer_factory:
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
name: adamW
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 2
30 changes: 21 additions & 9 deletions examples/mamba/create_config_mamba.py
@@ -1,9 +1,11 @@
""" Example python script to generate a YAML config file which can be used to run a training with nanotron. Refer to "examples" section in the `/README.md` for more information."""
import math
import os
import uuid

from config import MambaConfig, MambaInit, MambaModelConfig
from nanotron.config import (
AdamWOptimizerArgs,
CheckpointsArgs,
DataArgs,
DatasetStageArgs,
@@ -19,6 +21,10 @@
)
from nanotron.logging import human_format

new_job_id = uuid.uuid4()
job_id = str(new_job_id)[:8]
seed = 42

ssm_cfg_dtype = "bfloat16"
ssm_cfg = {
"d_state": 16,
@@ -37,7 +43,7 @@
# https://huggingface.co/state-spaces/mamba-790m/blob/main/config.json
model_config = MambaModelConfig(
d_model=1024,
num_hidden_layers=48,
num_hidden_layers=2,
vocab_size=50278,
ssm_cfg=ssm_cfg,
rms_norm=True,
@@ -88,24 +94,28 @@

seed = 42


optimizer = OptimizerArgs(
zero_stage=0,
weight_decay=0.01,
clip_grad=1.0,
accumulate_grad_in_fp32=True, # NOTE(fmom): because we are using PP=TP=DP=1
adam_eps=1e-08,
adam_beta1=0.9,
adam_beta2=0.95,
torch_adam_is_fused=True,
learning_rate_scheduler=LRSchedulerArgs(
learning_rate=0.0015,
lr_warmup_steps=30,
lr_warmup_style="linear",
lr_decay_style="cosine",
min_decay_lr=0.00015,
),
optimizer_factory=AdamWOptimizerArgs(
adam_eps=1e-08,
adam_beta1=0.9,
adam_beta2=0.95,
torch_adam_is_fused=True,
),
)


parallelism = ParallelismArgs(
dp=2,
pp=2,
@@ -128,17 +138,19 @@
)
]

model = ModelArgs(
init_method=MambaInit(initializer_range=0.02, rescale_prenorm_residual=True, n_residuals_per_layer=1),
model_config=model_config,
)

checkpoints_path = os.path.dirname(os.path.dirname(__file__)) + "/checkpoints"
os.makedirs(checkpoints_path, exist_ok=True)

config = MambaConfig(
general=GeneralArgs(project="test", run="mamba", seed=seed, ignore_sanity_checks=True),
checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=100),
parallelism=parallelism,
model=ModelArgs(
init_method=MambaInit(initializer_range=0.02, rescale_prenorm_residual=True, n_residuals_per_layer=1),
model_config=model_config,
),
model=model,
tokenizer=TokenizerArgs("gpt2"),
optimizer=optimizer,
logging=LoggingArgs(),
18 changes: 10 additions & 8 deletions examples/mamba/mamba.py
@@ -181,15 +181,13 @@ def __init__(
self.A_log = create_sharded_parameter_from_config(
parameter=A_log, pg=self.tp_pg, split_config=SplitConfig(split_dim=0)
)
self.A_log._no_weight_decay = True

# D "skip" parameter
self.D = create_sharded_parameter_from_config(
parameter=torch.ones(self.d_inner // self.tp_pg.size(), device=device),
pg=self.tp_pg,
split_config=SplitConfig(split_dim=0),
)
self.D._no_weight_decay = True

# self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
self.out_proj = TensorParallelRowLinear(
@@ -664,7 +662,7 @@ def get_block_compute_costs(self):

def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
"""
Get flops per second for a Mamba model.
Terms such as nonlinearities, biases, and layer normalization are omitted (https://arxiv.org/pdf/2001.08361.pdf)
"""
# world_size = self.parallel_context.world_pg.size()
@@ -723,6 +721,14 @@ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch

return model_flops_per_s, hardware_flops_per_s

def get_named_params_without_weight_decay(self):
# get full name with "A_log", "D"
named_param_without_weight_decay = []
for name, _ in self.named_parameters():
if "A_log" in name or "D" in name:
named_param_without_weight_decay.append(name)
return named_param_without_weight_decay


def masked_mean(loss, label_mask, dtype):
# type: (Tensor, Tensor, torch.dtype) -> Tensor
@@ -917,11 +923,7 @@ def init_model_randomly(self, config):
raise ValueError(f"Who the fuck is {param_name}?")

elif isinstance(module, Mamba):
# NOTE(fmom): nn.Parameter are initialized in Mamba __init__
# In Mamba, only those 3 parameters don't have weight decay.
if param_name in ["dt_bias", "A_log", "D"]:
param._no_weight_decay = True

pass
else:
raise Exception(f"Parameter {full_param_name} was not initialized")

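The new get_named_params_without_weight_decay hook replaces the per-parameter _no_weight_decay flags removed above. A minimal sketch of how such a list can drive parameter groups, assuming a plain torch.optim.AdamW rather than nanotron's exact optimizer wiring:

import torch

def build_param_groups(model, weight_decay):
    # Names reported by the model (for Mamba: the A_log and D parameters)
    # go into a group whose weight decay is forced to zero.
    no_decay_names = set(model.get_named_params_without_weight_decay())
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        (no_decay if name in no_decay_names else decay).append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# Example: optimizer = torch.optim.AdamW(build_param_groups(model, 0.01), lr=1e-3)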
11 changes: 7 additions & 4 deletions examples/moe/config_llamoe.py
@@ -4,6 +4,7 @@
from typing import Optional

from nanotron.config import (
AdamWOptimizerArgs,
CheckpointsArgs,
Config,
DataArgs,
@@ -99,11 +100,13 @@ def __post_init__(self):
weight_decay=0.01,
clip_grad=1.0,
accumulate_grad_in_fp32=False,
adam_eps=1e-08,
adam_beta1=0.9,
adam_beta2=0.95,
torch_adam_is_fused=True,
learning_rate_scheduler=learning_rate,
optimizer_factory=AdamWOptimizerArgs(
adam_eps=1e-08,
adam_beta1=0.9,
adam_beta2=0.95,
torch_adam_is_fused=True,
),
)

parallelism = ParallelismArgs(
46 changes: 30 additions & 16 deletions examples/moe/config_llamoe.yaml
@@ -5,18 +5,30 @@ checkpoints:
resume_checkpoint_path: /fsx/nouamane/projects/nanotron/examples/checkpoints
save_initial_state: true
data_stages:
- name: General purpose training
start_training_step: 1
data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 12
hf_dataset_config_name: null
hf_dataset_or_datasets: roneneldan/TinyStories
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 12
hf_dataset_config_name: null
hf_dataset_or_datasets: roneneldan/TinyStories
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: Stable Training Stage
start_training_step: 1
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 12
hf_dataset_config_name: null
hf_dataset_or_datasets: roneneldan/TinyStories
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: Annealing Phase
start_training_step: 10
general:
benchmark_csv_path: null
consumed_train_samples: null
@@ -60,9 +72,6 @@ model:
vocab_size: 32000
optimizer:
accumulate_grad_in_fp32: false
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0003
@@ -72,7 +81,12 @@ optimizer:
lr_warmup_steps: 100
lr_warmup_style: linear
min_decay_lr: 1.0e-05
torch_adam_is_fused: true
optimizer_factory:
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
name: adamW
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
1 change: 1 addition & 0 deletions examples/moe/requirements.txt
@@ -1 +1,2 @@
stanford-stk>=0.0.6
megablocks==0.5.1
21 changes: 15 additions & 6 deletions src/nanotron/config/config.py
@@ -262,20 +262,29 @@ def __post_init__(self):
self.min_decay_lr = self.learning_rate


@dataclass
class SGDOptimizerArgs:
name: str = "sgd"


@dataclass
class AdamWOptimizerArgs:
adam_eps: float
adam_beta1: float
adam_beta2: float
torch_adam_is_fused: bool
name: str = "adamW"


@dataclass
class OptimizerArgs:
"""Arguments related to the optimizer and learning rate"""

optimizer_factory: Union[SGDOptimizerArgs, AdamWOptimizerArgs]
zero_stage: int
weight_decay: float
clip_grad: Optional[float]

accumulate_grad_in_fp32: bool

adam_eps: float
adam_beta1: float
adam_beta2: float
torch_adam_is_fused: bool
learning_rate_scheduler: LRSchedulerArgs


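Downstream code can pick the optimizer from the factory's name field ("adamW" or "sgd"), while weight_decay stays on OptimizerArgs and can still be overridden per parameter group. A hypothetical illustration using plain torch optimizers, not nanotron's actual builder:

import torch

def build_optimizer(optimizer_args, param_groups):
    factory = optimizer_args.optimizer_factory
    lr = optimizer_args.learning_rate_scheduler.learning_rate
    if factory.name == "adamW":
        return torch.optim.AdamW(
            param_groups,
            lr=lr,
            betas=(factory.adam_beta1, factory.adam_beta2),
            eps=factory.adam_eps,
            weight_decay=optimizer_args.weight_decay,  # default; per-group values take precedence
            fused=factory.torch_adam_is_fused,  # fused AdamW requires CUDA tensors
        )
    if factory.name == "sgd":
        return torch.optim.SGD(param_groups, lr=lr, weight_decay=optimizer_args.weight_decay)
    raise ValueError(f"Unsupported optimizer factory: {factory.name}")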
