add sync all reduce
xrsrke committed Sep 18, 2024
1 parent e255511 commit d4a5997
Showing 4 changed files with 177 additions and 70 deletions.
6 changes: 3 additions & 3 deletions examples/config_tiny_llama.yaml
@@ -68,7 +68,7 @@ model:
use_cache: true
vocab_size: 256
optimizer:
accumulate_grad_in_fp32: true
accumulate_grad_in_fp32: false
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0003
@@ -87,9 +87,9 @@ optimizer:
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 2
dp: 1
expert_parallel_size: 1
pp: 2
pp: 1
pp_engine: 1f1b
tp: 2
tp_linear_async_communication: true
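For context on the parallelism change above: with dp=1, pp=1 and tp=2 the example config targets 2 GPUs instead of the previous 8, assuming the usual convention that the data-, pipeline- and tensor-parallel sizes compose multiplicatively (hypothetical helper, not repo code; expert parallelism is left out):

```python
def required_world_size(dp: int, pp: int, tp: int) -> int:
    # Assumes the three parallel dimensions multiply into the GPU count.
    return dp * pp * tp

assert required_world_size(dp=2, pp=2, tp=2) == 8   # old example config
assert required_world_size(dp=1, pp=1, tp=2) == 2   # new example config
```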
117 changes: 98 additions & 19 deletions src/nanotron/helpers.py
@@ -750,26 +750,105 @@ def calculate_kurtosis(X):
return kurtosis


def compute_tensor_stats(tensor):
def compute_kurtosis_from_global_values(tensor, global_mean, global_var, tp_group):
"""
Compute kurtosis using global mean and variance across distributed ranks.
Arguments:
- tensor: local tensor on this rank.
- global_mean: the global mean after all-reduce sum.
- global_var: the global variance after all-reduce sum.
- tp_group: the process group for all-reduce operations.
Returns:
- kurtosis: the global kurtosis.
"""
# Step 1: Compute the fourth central moment (E[(X - global_mean)^4])
central_diff = tensor - global_mean
fourth_moment = torch.mean(central_diff**4)

# Perform all-reduce to get the global fourth moment sum
dist.all_reduce(fourth_moment, op=dist.ReduceOp.SUM, group=tp_group)

# Normalize the fourth moment across all ranks
global_fourth_moment = fourth_moment / tensor.numel()

# Step 2: Compute the global kurtosis
global_kurtosis = global_fourth_moment / (global_var**2)

# Handle edge case where the variance is zero
if torch.isnan(global_kurtosis):
global_kurtosis = torch.tensor(0.0)

return global_kurtosis


def compute_tensor_stats(tensor, parallel_context=None):
def compute_snr(tensor):
mean = torch.mean(tensor)
std = torch.std(tensor)
snr = mean / std
return snr

def compute_snr_from_global_mean_and_std(mean, std):
snr = mean / std
return snr

# mean = tensor.mean()
# dist.all_reduce(mean, op=dist.ReduceOp.SUM, group=parallel_context.world_pg)
mean = tensor.mean()
std = tensor.std()
var = tensor.var()
l1_norm = tensor.norm(p=1)
l2_norm = tensor.norm(p=2)
rms = tensor.pow(2).mean().sqrt()
min = tensor.min()
max = tensor.max()
abs_max = tensor.abs().max()

# NOTE: global mean
if parallel_context is not None:
tp_group = parallel_context.tp_pg
dist.all_reduce(mean, op=dist.ReduceOp.SUM, group=tp_group)
dist.all_reduce(std, op=dist.ReduceOp.SUM, group=tp_group)
dist.all_reduce(var, op=dist.ReduceOp.SUM, group=tp_group)
dist.all_reduce(l1_norm, op=dist.ReduceOp.SUM, group=tp_group)
dist.all_reduce(l2_norm, op=dist.ReduceOp.SUM, group=tp_group)
dist.all_reduce(rms, op=dist.ReduceOp.SUM, group=tp_group)
dist.all_reduce(min, op=dist.ReduceOp.MIN, group=tp_group)
dist.all_reduce(max, op=dist.ReduceOp.MAX, group=tp_group)
dist.all_reduce(abs_max, op=dist.ReduceOp.MAX, group=tp_group)

snr = compute_snr_from_global_mean_and_std(mean, std)
kurtosis = compute_kurtosis_from_global_values(tensor, mean, var, tp_group)

# return {
# "mean": tensor.mean(),
# "std": tensor.std(),
# "var": tensor.var(),
# "l1_norm": tensor.norm(p=1),
# "l2_norm": tensor.norm(p=2),
# "rms": tensor.pow(2).mean().sqrt(),
# "min": tensor.min(),
# "max": tensor.max(),
# "amax": tensor.abs().max(),
# # "abs_mean": tensor.abs().mean(),
# "kurtosis": calculate_kurtosis(tensor),
# "snr": compute_snr(tensor),
# }
return {
"mean": tensor.mean(),
"std": tensor.std(),
"var": tensor.var(),
"l1_norm": tensor.norm(p=1),
"l2_norm": tensor.norm(p=2),
"rms": tensor.pow(2).mean().sqrt(),
"min": tensor.min(),
"max": tensor.max(),
"amax": tensor.abs().max(),
"abs_mean": tensor.abs().mean(),
"kurtosis": calculate_kurtosis(tensor),
"snr": compute_snr(tensor),
"mean": mean,
"std": std,
"var": var,
"l1_norm": l1_norm,
"l2_norm": l2_norm,
"rms": rms,
"min": min,
"max": max,
"amax": abs_max,
# "abs_mean": tensor.abs().mean(),
"kurtosis": kurtosis,
"snr": snr,
}
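
For reference, an exact way to get global moments that does not depend on shard sizes is to all-reduce raw sums and element counts rather than per-rank means and standard deviations. A minimal sketch with plain torch.distributed (not nanotron code; it assumes the process group is already initialized and each rank holds a disjoint shard of one logical tensor):

```python
import torch
import torch.distributed as dist

def global_tensor_stats(shard: torch.Tensor, group=None) -> dict:
    """Exact global mean/var/kurtosis over disjoint shards of one tensor."""
    n = torch.tensor(float(shard.numel()), device=shard.device)
    s1 = shard.sum()          # sum of x
    s2 = shard.pow(2).sum()   # sum of x^2
    for t in (n, s1, s2):
        dist.all_reduce(t, op=dist.ReduceOp.SUM, group=group)

    mean = s1 / n
    var = s2 / n - mean**2    # E[x^2] - E[x]^2

    # The fourth central moment needs the global mean, hence a second reduce.
    s4 = (shard - mean).pow(4).sum()
    dist.all_reduce(s4, op=dist.ReduceOp.SUM, group=group)
    kurtosis = (s4 / n) / var.clamp_min(1e-12) ** 2

    return {"mean": mean, "std": var.clamp_min(0).sqrt(), "var": var, "kurtosis": kurtosis}
```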


@@ -832,15 +911,15 @@ def compute_stats(tensors, metrics: List[str] = ["amax", "mean", "std", "var", "
# stats[key] = {}
# for metric in metrics:
# stats[key][metric] = NAME_TO_FUNC[metric](tensor)
stats[key] = compute_tensor_stats(tensor)
stats[key] = compute_tensor_stats(tensor, parallel_context)

# NOTE: now all reduce mean this across tp ranks
from torch.distributed import ReduceOp
# from torch.distributed import ReduceOp

tp_group = parallel_context.tp_pg
for metric_name, metric_value in stats[key].items():
stats[key][metric_name] = torch.tensor(metric_value, device=tensor.device, dtype=tensor.dtype)
dist.all_reduce(stats[key][metric_name], op=ReduceOp.MAX, group=tp_group)
# tp_group = parallel_context.tp_pg
# for metric_name, metric_value in stats[key].items():
# stats[key][metric_name] = torch.tensor(metric_value, device=tensor.device, dtype=tensor.dtype)
# # dist.all_reduce(stats[key][metric_name], op=ReduceOp.MAX, group=tp_group)

return stats[list(stats.keys())[0]] if len(stats) == 1 else stats

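A quick way to sanity-check this kind of sharded aggregation is a two-process gloo script (hypothetical, not part of the commit; run with `torchrun --nproc_per_node=2 check_stats.py`):

```python
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")
rank, world = dist.get_rank(), dist.get_world_size()

torch.manual_seed(0)
full = torch.randn(world * 1024)        # the "unsharded" reference tensor
shard = full.chunk(world)[rank]         # this rank's shard

# Global mean via sum/count reduction, independent of shard sizes.
n = torch.tensor(float(shard.numel()))
s = shard.sum()
dist.all_reduce(n, op=dist.ReduceOp.SUM)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
global_mean = s / n

if rank == 0:
    assert torch.allclose(global_mean, full.mean(), atol=1e-6)
    print("sharded mean matches unsharded mean:", global_mean.item())
dist.destroy_process_group()
```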
41 changes: 28 additions & 13 deletions src/nanotron/optim/adamw.py
@@ -54,11 +54,22 @@ def __setstate__(self, state):
group.setdefault("amsgrad", False)

def _get_optim_logs(self, loggings):
def _find_param_name(p):
if hasattr(self, "grad_accumulator"):
for name in self.grad_accumulator.parameters:
if self.grad_accumulator.parameters[name]["fp32"] is p:
return name
else:
return self.params_id_to_param_names[id(p)]

raise ValueError("Could not find the parameter name")

from nanotron.helpers import convert_logs_to_flat_logs

optim_loggings = {}
for p in loggings:
param_name = self.params_id_to_param_names[id(p)]
# param_name = self.params_id_to_param_names[id(p)]
param_name = _find_param_name(p)
optim_loggings[param_name] = loggings[p]
return convert_logs_to_flat_logs(optim_loggings)
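
The `_find_param_name` helper added above handles the case where gradient accumulation is active: the optimizer then steps on fp32 master copies, so `id(p)` no longer matches the model's half-precision parameters and the name has to be recovered from the accumulator. A compact paraphrase (illustrative only; it assumes `grad_accumulator.parameters` maps names to dicts holding an `"fp32"` master tensor, as the diff suggests):

```python
def find_param_name(p, grad_accumulator=None, params_id_to_param_names=None):
    # With a grad accumulator, `p` is the fp32 master copy, so match by
    # identity against the accumulator's registry instead of the id -> name map.
    if grad_accumulator is not None:
        for name, slot in grad_accumulator.parameters.items():
            if slot["fp32"] is p:
                return name
        raise ValueError("Could not find the parameter name")
    return params_id_to_param_names[id(p)]
```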

@@ -171,22 +182,26 @@ def step(self, closure=None):
loggings[p]["bias_correction1"] = {"value": bias_correction1}
loggings[p]["bias_correction2"] = {"value": bias_correction2}

loggings[p]["exp_avg"] = compute_tensor_stats(exp_avg)
loggings[p]["exp_avg_sq"] = compute_tensor_stats(exp_avg_sq)
loggings[p]["exp_avg_hat"] = compute_tensor_stats(exp_avg_hat)
loggings[p]["exp_avg_sq_hat"] = compute_tensor_stats(exp_avg_sq_hat)
loggings[p]["exp_avg"] = compute_tensor_stats(exp_avg, self.parallel_context)
loggings[p]["exp_avg_sq"] = compute_tensor_stats(exp_avg_sq, self.parallel_context)
loggings[p]["exp_avg_hat"] = compute_tensor_stats(exp_avg_hat, self.parallel_context)
loggings[p]["exp_avg_sq_hat"] = compute_tensor_stats(exp_avg_sq_hat, self.parallel_context)

loggings[p]["normalized_grad"] = compute_tensor_stats(normalized_grad)
loggings[p]["normalized_grad"] = compute_tensor_stats(normalized_grad, self.parallel_context)
loggings[p]["normalized_grad_without_adam_eps"] = compute_tensor_stats(
normalized_grad_without_adam_eps
normalized_grad_without_adam_eps, self.parallel_context
)
loggings[p]["weight_decay_grad"] = compute_tensor_stats(weight_decay_grad)
loggings[p]["weight_decay_grad"] = compute_tensor_stats(weight_decay_grad, self.parallel_context)

loggings[p]["fp32_p"] = compute_tensor_stats(p.data)
loggings[p]["fp32_new_changes_in_p"] = compute_tensor_stats(total_new_weight_changes)
loggings[p]["fp32_new_changes_from_grad"] = compute_tensor_stats(new_weight_changes_from_grad)
loggings[p]["fp32_p"] = compute_tensor_stats(p.data, self.parallel_context)
loggings[p]["fp32_new_changes_in_p"] = compute_tensor_stats(
total_new_weight_changes, self.parallel_context
)
loggings[p]["fp32_new_changes_from_grad"] = compute_tensor_stats(
new_weight_changes_from_grad, self.parallel_context
)

loggings[p]["fp32_grad"] = compute_tensor_stats(grad)
loggings[p]["fp32_grad"] = compute_tensor_stats(grad, self.parallel_context)
loggings[p]["weight_norm_and_normalized_grad_update_norm_ratio"] = {
"value": weight_norm_and_normalized_grad_update_norm_ratio
}
@@ -198,7 +213,7 @@ def step(self, closure=None):

if group["weight_decay"] != 0:
loggings[p]["fp32_new_changes_from_weight_decay"] = compute_tensor_stats(
new_weight_changes_from_weight_decay
new_weight_changes_from_weight_decay, self.parallel_context
)
loggings[p]["weight_norm_and_weight_decay_update_norm_ratio"] = {
"value": weight_norm_and_weight_decay_update_norm_ratio
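The `bias_correction1`/`bias_correction2` and `exp_avg_hat`/`exp_avg_sq_hat` quantities whose statistics are logged above are the standard bias-corrected Adam moments. A reference sketch (standard AdamW math; the function itself is illustrative and not part of the commit):

```python
import torch

def bias_corrected_moments(exp_avg, exp_avg_sq, step, beta1=0.9, beta2=0.999):
    # Standard Adam bias correction: early steps are scaled up because the
    # exponential moving averages start at zero.
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step
    exp_avg_hat = exp_avg / bias_correction1
    exp_avg_sq_hat = exp_avg_sq / bias_correction2
    return exp_avg_hat, exp_avg_sq_hat

m, v = torch.full((4,), 0.1), torch.full((4,), 0.01)
print(bias_corrected_moments(m, v, step=1))   # (tensor of 1.0s, tensor of 10.0s)
```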
83 changes: 48 additions & 35 deletions src/nanotron/trainer.py
@@ -257,17 +257,6 @@ def __init__(

self.post_init()

self.params_id_to_param_names = {
id(param): name for name, param in self.unwrapped_model.get_named_params_with_correct_tied()
}
from nanotron.optim.optimizer_from_gradient_accumulator import OptimizerFromGradientAccumulator

if self.optimizer.__class__ == OptimizerFromGradientAccumulator:
self.optimizer.optimizer.optimizer.params_id_to_param_names = self.params_id_to_param_names
self.optimizer.optimizer.optimizer.grad_accumulator = self.grad_accumulator
else:
self.optimizer.optimizer.params_id_to_param_names = self.params_id_to_param_names

def pre_init(self):
pass

@@ -411,6 +400,19 @@ def train(
) -> None:
self.pre_training(**kwargs)

self.params_id_to_param_names = {
id(param): name for name, param in self.unwrapped_model.get_named_params_with_correct_tied()
}
from nanotron.optim.optimizer_from_gradient_accumulator import OptimizerFromGradientAccumulator

if self.optimizer.__class__ == OptimizerFromGradientAccumulator:
self.optimizer.optimizer.optimizer.params_id_to_param_names = self.params_id_to_param_names
self.optimizer.optimizer.optimizer.grad_accumulator = self.grad_accumulator
self.optimizer.optimizer.optimizer.parallel_context = self.parallel_context
else:
self.optimizer.optimizer.params_id_to_param_names = self.params_id_to_param_names
self.optimizer.optimizer.parallel_context = self.parallel_context

if self.config.checkpoints.save_initial_state and self.init_checkpoint_path is None:
self.save_checkpoint()

@@ -431,7 +433,9 @@ def train(

is_ready_for_normal_log = (
(self.iteration_step - 1) % self.config.logging.iteration_step_info_interval == 0
) and (dist.get_rank(self.parallel_context.world_pg) == 0)
) and (
dist.get_rank(self.parallel_context.dp_pg) == 0
) # NOTE: only log the first dp rank

prof = get_profiler(config=self.config)
torch.cuda.empty_cache()
@@ -440,7 +444,9 @@ def train(
if isinstance(prof, torch.profiler.profile):
prof.step()

is_ready_to_log = is_ready_for_normal_log and self.config.logging.monitor_fwd_states is True
is_ready_to_log = (
is_ready_for_normal_log and self.config.logging.monitor_fwd_states is True and wandb is not None
)
constants.is_ready_to_log = is_ready_to_log

if is_ready_to_log is True:
@@ -466,37 +472,44 @@ def train(
self.train_step_logs(outputs=outputs, loss_avg=loss_avg)

if is_ready_to_log is True:
if wandb is not None:
for handle in state_handles:
handle.remove()
for handle in state_handles:
handle.remove()

detailed_logs = {}
detailed_logs = {}

# optim_logs = get_optim_logs(self.params_id_to_param_names, self.optimizer.optimizer, prefix="")
if hasattr(self.optimizer.optimizer, "loggings"):
optim_logs = self.optimizer.optimizer.loggings
detailed_logs.update(optim_logs)
# optim_logs = get_optim_logs(self.params_id_to_param_names, self.optimizer.optimizer, prefix="")
if self.optimizer.__class__ == OptimizerFromGradientAccumulator:
optim_logs = self.optimizer.optimizer.optimizer.loggings
detailed_logs.update(optim_logs)
elif hasattr(self.optimizer.optimizer, "loggings"):
optim_logs = self.optimizer.optimizer.loggings
detailed_logs.update(optim_logs)

from nanotron import constants
from nanotron import constants

detailed_logs.update(convert_logs_to_flat_logs(constants.NN_STATES))
# NOTE: convert tensor to float for logging
detailed_logs = {
k: v.item() if isinstance(v, torch.Tensor) else v for k, v in detailed_logs.items()
}
detailed_logs.update(convert_logs_to_flat_logs(constants.NN_STATES))
# NOTE: convert tensor to float for logging
detailed_logs = {
k: v.item() if isinstance(v, torch.Tensor) else v for k, v in detailed_logs.items()
}

# NOTE: save the detailed logs, path/{run_name}/{iteration_step}/logs.json
# NOTE: save the detailed logs, path/{run_name}/{iteration_step}/logs.json

# DEBUG_SAVE_PATH = DEBUG_SAVE_PATH.format(self.config.general.run, self.iteration_step)
# debug_save_path = get_debug_save_path(self.config.general.run, self.iteration_step)
# with open(f"{debug_save_path}/logs.json", "w") as f:
# json.dump(detailed_logs, f)
# DEBUG_SAVE_PATH = DEBUG_SAVE_PATH.format(self.config.general.run, self.iteration_step)
# debug_save_path = get_debug_save_path(self.config.general.run, self.iteration_step)
# with open(f"{debug_save_path}/logs.json", "w") as f:
# json.dump(detailed_logs, f)

if dist.get_rank(self.parallel_context.world_pg) == 0:
# NOTE: all ranks has the same stats, only rank 0 logs it
wandb.log({**detailed_logs, "iteration_step": self.iteration_step})
constants.NN_STATES = {}

if hasattr(self.optimizer.optimizer, "loggings"):
self.optimizer.optimizer.loggings = {}
constants.NN_STATES = {}

if self.optimizer.__class__ == OptimizerFromGradientAccumulator:
self.optimizer.optimizer.optimizer.loggings = {}
elif hasattr(self.optimizer.optimizer, "loggings"):
self.optimizer.optimizer.loggings = {}

# Checkpoint
if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0:
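To summarize the trainer changes: the optimizer/logging wiring moves from `__init__` into `train()`, the innermost optimizer now also receives the parallel context, and detailed logging is gated on the first data-parallel rank and on wandb being available. A condensed sketch of that wiring (attribute names taken from the diff above; the unwrapping depth for `OptimizerFromGradientAccumulator` is assumed to be exactly as shown in the hunk):

```python
def attach_logging_context(trainer):
    # Hypothetical helper mirroring the new code in train(); not repo code.
    params_id_to_param_names = {
        id(param): name
        for name, param in trainer.unwrapped_model.get_named_params_with_correct_tied()
    }
    inner = trainer.optimizer.optimizer
    if type(trainer.optimizer).__name__ == "OptimizerFromGradientAccumulator":
        inner = inner.optimizer                      # one extra wrapper to unwrap
        inner.grad_accumulator = trainer.grad_accumulator
    inner.params_id_to_param_names = params_id_to_param_names
    inner.parallel_context = trainer.parallel_context
```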
