
Commit

add custom Adam for bfloat16
xrsrke committed Jul 9, 2024
1 parent 3376da7 commit 89a36f9
Showing 3 changed files with 24 additions and 11 deletions.
8 changes: 4 additions & 4 deletions examples/fp8/ablations/configs/sanity_bf16.yaml
@@ -73,9 +73,9 @@ model:
 intermediate_size: 2048
 is_llama_config: true
 max_position_embeddings: 256
-num_attention_heads: 4
+num_attention_heads: 16
 num_hidden_layers: 2
-num_key_value_heads: 4
+num_key_value_heads: 16
 pad_token_id: null
 pretraining_tp: 1
 rms_norm_eps: 1.0e-05
@@ -123,7 +123,7 @@ optimizer:
 lr_decay_starting_step: null
 lr_decay_steps: null
 lr_decay_style: cosine
-lr_warmup_steps: 1000 # 10% warm up of total training steps
+lr_warmup_steps: 200 # 10% warm up of total training steps
 lr_warmup_style: linear
 min_decay_lr: 0.00006
 optimizer_factory:
@@ -158,7 +158,7 @@ tokens:
 batch_accumulation_per_replica: 1
 limit_test_batches: 0
 limit_val_batches: 0
-micro_batch_size: 256 # 256
+micro_batch_size: 128 # 256
 # micro_batch_size: 1
 sequence_length: 256
 train_steps: 24376
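
Two quick checks on the new values, written out as a standalone sketch (it only restates numbers visible in this config and is not part of the commit):

```python
# With num_key_value_heads == num_attention_heads, every query head gets its own
# KV head, i.e. plain multi-head attention rather than grouped-query attention.
num_attention_heads = 16
num_key_value_heads = 16
assert num_attention_heads % num_key_value_heads == 0
queries_per_kv_head = num_attention_heads // num_key_value_heads  # 1

# Tokens per optimizer step per replica after halving the micro batch size.
micro_batch_size = 128
sequence_length = 256
batch_accumulation_per_replica = 1
tokens_per_step_per_replica = (
    micro_batch_size * sequence_length * batch_accumulation_per_replica
)
assert tokens_per_step_per_replica == 32_768
```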
3 changes: 2 additions & 1 deletion src/nanotron/fp8/optim.py
@@ -167,7 +167,8 @@ def step(self, closure=None):
 # print(f"[Ref Adam] exp_avg: {exp_avg[:2, :2]} \n")
 # print(f"[Ref Adam] denom: {denom[:2, :2]} \n")

-p.data.addcdiv_(-step_size, exp_avg, denom)
+# p.data.addcdiv_(-step_size, exp_avg, denom)
+p.data = p.data - step_size * (exp_avg / denom)

 # if p.ndim != 1:
 # print(f"[Ref Adam] updated p: {p.data[:2, :2]} \n")
24 changes: 18 additions & 6 deletions src/nanotron/helpers.py
@@ -372,7 +372,24 @@ def optimizer(param_groups):
     else:

         def optimizer(param_groups):
-            return torch.optim.Adam(
+            # return torch.optim.Adam(
+            #     param_groups,
+            #     lr=optimizer_args.learning_rate_scheduler.learning_rate,
+            #     weight_decay=optimizer_args.weight_decay,
+            #     eps=optimizer_args.optimizer_factory.adam_eps,
+            #     betas=(
+            #         optimizer_args.optimizer_factory.adam_beta1,
+            #         optimizer_args.optimizer_factory.adam_beta2,
+            #     ),
+            #     # fused=optimizer_args.optimizer_factory.torch_adam_is_fused,
+            #     # NOTE: fused (bool, optional) – whether the fused implementation (CUDA only) is used.
+            #     # Currently, torch.float64, torch.float32, torch.float16, and torch.bfloat16
+            #     # in FP8 training, model parameters are INT8
+            #     fused=False,
+            # )
+            from nanotron.fp8.optim import Adam
+
+            return Adam(
                 param_groups,
                 lr=optimizer_args.learning_rate_scheduler.learning_rate,
                 weight_decay=optimizer_args.weight_decay,
@@ -381,11 +398,6 @@ def optimizer(param_groups):
                     optimizer_args.optimizer_factory.adam_beta1,
                     optimizer_args.optimizer_factory.adam_beta2,
                 ),
-                # fused=optimizer_args.optimizer_factory.torch_adam_is_fused,
-                # NOTE: fused (bool, optional) – whether the fused implementation (CUDA only) is used.
-                # Currently, torch.float64, torch.float32, torch.float16, and torch.bfloat16
-                # in FP8 training, model parameters are INT8
-                fused=False,
             )

 elif optimizer_args.optimizer_factory.name == "sgd":
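
For context, `optimizer` here is a factory closure: `helpers.py` builds a callable that the trainer later invokes with the parameter groups, and this commit makes that callable construct `nanotron.fp8.optim.Adam` instead of `torch.optim.Adam`, with the same keyword arguments minus `fused`. A rough sketch of the pattern in isolation, using a hypothetical stand-in for the config objects (the real `optimizer_args` dataclasses are not shown in this diff):

```python
from dataclasses import dataclass
from typing import Callable, Iterable

from nanotron.fp8.optim import Adam  # the custom Adam this commit switches to


@dataclass
class AdamHyperParams:
    # Hypothetical flattened stand-in for the optimizer_args.* fields above.
    learning_rate: float = 3e-4
    weight_decay: float = 0.1
    adam_eps: float = 1e-8
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95


def make_optimizer_factory(hp: AdamHyperParams) -> Callable:
    # Same closure pattern as in helpers.py: the factory is handed to the
    # trainer, which calls it once the named parameter groups exist.
    def optimizer(param_groups: Iterable):
        return Adam(
            param_groups,
            lr=hp.learning_rate,
            weight_decay=hp.weight_decay,
            eps=hp.adam_eps,
            betas=(hp.adam_beta1, hp.adam_beta2),
        )

    return optimizer
```

The `fused` flag and its explanatory comments are dropped rather than carried over: `fused` belongs to torch's built-in Adam, and nothing in this diff indicates the custom class accepts it.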
