Showing 8 changed files with 180 additions and 18 deletions.
@@ -0,0 +1,74 @@
from dataclasses import dataclass
from typing import List, Optional

import torch

from nanotron.fp8.recipe import FP8LinearRecipe, FP8OptimRecipe


@dataclass
class FP8LayerArgs(FP8LinearRecipe):
    module_name: Optional[str] = None

    def __post_init__(self):
        assert self.module_name is not None, "module_name must be specified"


@dataclass
class FP8Args:
    # NOTE: the dtype of the residual stream (i.e. non-FP8 operations)
    resid_dtype: torch.dtype = torch.float32
    # NOTE: the accumulation dtype for FP8 operations
    accum_dtype: torch.dtype = torch.bfloat16

    model: Optional[List[FP8LayerArgs]] = None
    optim: Optional[FP8OptimRecipe] = None

    run_fp8_sanity_check: bool = False

    clipped_softmax: bool = False
    clipped_softmax_zeta: Optional[float] = None
    clipped_softmax_gamma: Optional[float] = None

    gated_attention: bool = False

    layer_scale: bool = False
    layer_scale_init: Optional[str] = None
    layer_scale_lr: Optional[float] = None
    layer_scale_wdecay: Optional[float] = None

    qk_norm: bool = False
    qk_norm_before_pos: bool = False

    smooth_quant: Optional[bool] = None
    smooth_quant_migration_strength: Optional[float] = 0.5

    stochastic_rounding: bool = False
    update_clipping: bool = False
    skip_param_update_if_nan: bool = False

    sync_amax_in_input: bool = False
    sync_amax_in_weight: bool = False
    sync_amax_in_igrad: bool = False
    sync_amax_in_wgrad: bool = False
    sync_amax_func: str = "default"
    weight_decay_without_lr_decay: bool = False

    adam_atan2: bool = False
    adam_atan2_lambda: Optional[float] = None

    qkv_clipping: bool = False
    qkv_clipping_factor: Optional[float] = None
    is_save_grad_for_accum_debugging: bool = False
    is_directly_keep_accum_grad_of_fp8: bool = False

    triton_rms_norm: bool = False

    is_debugging: bool = False
    is_sanity_logging: bool = False
    is_post_scaling_all_reduce: bool = True
    # NOTE: 1.0e-6 was the default
    gradient_clipping_eps: float = 1.0e-6

    is_quant_all_except_first_and_last: Optional[bool] = None
    fp8_linear_config_temp: Optional[FP8LayerArgs] = None
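
As a rough illustration of the per-module override pattern above (the `model: Optional[List[FP8LayerArgs]]` field keyed by `module_name`), here is a minimal, self-contained sketch of how such a list might be resolved for a given module. The stand-in classes, the `find_layer_args` helper, and the module path are hypothetical simplifications for illustration, not nanotron's actual lookup code:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class _LayerArgs:
    # Stand-in for FP8LayerArgs: only the name matching matters for this sketch.
    module_name: Optional[str] = None


@dataclass
class _FP8Config:
    # Stand-in for FP8Args.model: per-module overrides keyed by module_name.
    model: Optional[List[_LayerArgs]] = None

    def find_layer_args(self, module_name: str) -> Optional[_LayerArgs]:
        # Hypothetical helper: return the override whose module_name matches, if any.
        if self.model is None:
            return None
        return next((args for args in self.model if args.module_name == module_name), None)


config = _FP8Config(model=[_LayerArgs(module_name="model.decoder.0.mlp.down_proj")])
assert config.find_layer_args("model.decoder.0.mlp.down_proj") is not None
assert config.find_layer_args("model.decoder.1.mlp.down_proj") is None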
@@ -14,3 +14,4 @@

# TODO(xrsrke): remove this shit
ITERATION_STEP = 1
CONFIG = None
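
The new `CONFIG` global is consumed elsewhere in this commit: `linear` below reads `constants.CONFIG.fp8` to fetch the SmoothQuant migration strength. Presumably the training entry point assigns the parsed run config to it at startup; a hedged sketch of that wiring (the `setup_fp8_constants` helper name is made up, not part of this commit):

from nanotron import constants


def setup_fp8_constants(config) -> None:
    # Hypothetical wiring: stash the parsed run config in the module-level global
    # so low-level FP8 code (e.g. `linear` below) can read `constants.CONFIG.fp8`
    # without threading the config through every call.
    constants.CONFIG = config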
@@ -0,0 +1,83 @@
from typing import Optional, Tuple

import torch

from nanotron.fp8.linear import FP8LinearMeta
from nanotron.fp8.recipe import FP8LinearRecipe
from nanotron.fp8.tensor import FP8Tensor


def smooth_quant(input: torch.Tensor, weight: FP8Tensor, alpha: float) -> Tuple[torch.Tensor, FP8Tensor]:
    """
    An implementation of SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models
    https://arxiv.org/abs/2211.10438
    """
    # Compute the per-input-channel smoothing factor
    input_s = torch.amax(torch.abs(input), dim=(0, 1), keepdim=True)
    w_s = torch.amax(torch.abs(weight._orig_data_after_set_data), dim=0)

    s = input_s.squeeze().pow(alpha) / w_s.pow(1 - alpha)

    # NOTE: smooth a detached copy so the smoothing op is not added to the
    # computational graph, keeping the original graph intact
    X_smoothed = input.detach() / s.unsqueeze(dim=0).unsqueeze(dim=0)
    X_smoothed.requires_grad_()

    with torch.no_grad():
        W_smoothed = weight._orig_data_after_set_data * s.unsqueeze(0)
        weight.set_data(W_smoothed)

    return X_smoothed, weight


def linear(
    input: torch.Tensor,
    weight: FP8Tensor,
    bias: Optional[torch.Tensor] = None,
    metadatas: FP8LinearMeta = None,
    recipe: FP8LinearRecipe = None,
    name: Optional[str] = None,
):
    from typing import cast

    from nanotron import constants
    from nanotron.config.fp8_config import FP8Args

    if recipe.smooth_quant is True:
        fp8_config = cast(FP8Args, constants.CONFIG.fp8)
        migration_strength = fp8_config.smooth_quant_migration_strength
        input, weight = smooth_quant(input, weight, alpha=migration_strength)

    assert metadatas is not None, "metadatas must be specified"
    assert recipe is not None, "recipe must be specified"
    assert input.device != torch.device("cpu"), "FP8Linear only supports CUDA tensors"

    # TODO(xrsrke): refactor this out, don't duplicate the code
    from einops import rearrange

    from nanotron.fp8.linear import _FP8Matmul

    seq_len = None
    batch_size = None
    is_input_flat = False
    if input.ndim == 3:
        batch_size = input.shape[0]
        seq_len = input.shape[1]
        is_input_flat = True
        input = rearrange(input, "b n h -> (b n) h")
    elif input.ndim > 3:
        raise ValueError(f"Unsupported input shape: {input.shape}")

    # NOTE: a phony tensor that makes PyTorch trigger the backward pass;
    # the weight's and bias's requires_grad are set to False
    # so that we can compute the gradients ourselves with the fp8 kernels
    phony = torch.empty(0, device=input.device, requires_grad=True)
    output = torch.zeros(input.shape[0], weight.shape[0], device="cuda", dtype=recipe.accum_dtype)
    output, _ = _FP8Matmul.apply(input, weight, output, phony, metadatas, recipe, name)

    # TODO(xrsrke): add support for adding bias in fp8
    # TODO(xrsrke): support returning an fp8 tensor as output,
    # since we will quantize it back to FP8 anyway in the next linear
    output = rearrange(output, "(b n) h -> b n h", n=seq_len, b=batch_size) if is_input_flat is True else output
    output = output if bias is None else output + bias
    return output
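
A minimal, self-contained check (plain tensors only, no FP8Tensor or nanotron imports) of the identity that `smooth_quant` relies on: dividing the activations by a per-input-channel factor `s` and multiplying the weight by the same `s` leaves the linear's output unchanged, while shifting the activation outliers' dynamic range into the weight. The shapes and seed are arbitrary:

import torch

torch.manual_seed(0)
batch, seq, d_in, d_out = 2, 4, 8, 16
alpha = 0.5  # migration strength, as in smooth_quant_migration_strength

x = torch.randn(batch, seq, d_in)
w = torch.randn(d_out, d_in)  # same layout as an nn.Linear weight

input_s = x.abs().amax(dim=(0, 1))           # per-input-channel activation max, shape (d_in,)
w_s = w.abs().amax(dim=0)                    # per-input-channel weight max, shape (d_in,)
s = input_s.pow(alpha) / w_s.pow(1 - alpha)  # smoothing factor, shape (d_in,)

x_smoothed = x / s  # broadcast over the last (hidden) dim
w_smoothed = w * s  # broadcast over the last (in_features) dim

# (x / s) @ (w * s).T == x @ w.T up to floating-point error
ref = x @ w.T
out = x_smoothed @ w_smoothed.T
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)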