Commit: add fp8 linear
xrsrke committed Oct 28, 2024
1 parent 6c9a4d0 commit 0f8f672
Showing 8 changed files with 180 additions and 18 deletions.
74 changes: 74 additions & 0 deletions src/nanotron/config/fp8_config.py
@@ -0,0 +1,74 @@
from dataclasses import dataclass
from typing import List, Optional

import torch

from nanotron.fp8.recipe import FP8LinearRecipe, FP8OptimRecipe


@dataclass
class FP8LayerArgs(FP8LinearRecipe):
    module_name: Optional[str] = None

    def __post_init__(self):
        assert self.module_name is not None, "module_name must be specified"


@dataclass
class FP8Args:
    # NOTE: this is the datatype for residual stream (aka: non-fp8 operation)
    resid_dtype: torch.dtype = torch.float32
    # NOTE: the datatype for fp8 operation's accumulation
    accum_dtype: torch.dtype = torch.bfloat16

    model: Optional[List[FP8LayerArgs]] = None
    optim: Optional[FP8OptimRecipe] = None

    run_fp8_sanity_check: bool = False

    clipped_softmax: bool = False
    clipped_softmax_zeta: Optional[float] = None
    clipped_softmax_gamma: Optional[float] = None

    gated_attention: bool = False

    layer_scale: bool = False
    layer_scale_init: Optional[str] = None
    layer_scale_lr: Optional[float] = None
    layer_scale_wdecay: Optional[float] = None

    qk_norm: bool = False
    qk_norm_before_pos: bool = False

    smooth_quant: Optional[bool] = None
    smooth_quant_migration_strength: Optional[float] = 0.5

    stochastic_rounding: bool = False
    update_clipping: bool = False
    skip_param_update_if_nan: bool = False

    sync_amax_in_input: bool = False
    sync_amax_in_weight: bool = False
    sync_amax_in_igrad: bool = False
    sync_amax_in_wgrad: bool = False
    sync_amax_func: str = "default"
    weight_decay_without_lr_decay: bool = False

    adam_atan2: bool = False
    adam_atan2_lambda: Optional[float] = None

    qkv_clipping: bool = False
    qkv_clipping_factor: Optional[float] = None
    is_save_grad_for_accum_debugging: bool = False
    is_directly_keep_accum_grad_of_fp8: bool = False

    triton_rms_norm: bool = False

    is_debugging: bool = False
    is_sanity_logging: bool = False
    is_post_scaling_all_reduce: bool = True
    # NOTE: 1.0e-6 was the default
    gradient_clipping_eps: float = 1.0e-6

    is_quant_all_except_first_and_last: Optional[bool] = None
    fp8_linear_config_temp: Optional[FP8LayerArgs] = None
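For orientation, the new config can be constructed directly. The following is a minimal, illustrative sketch (not part of the commit); it only uses fields declared in the dataclass above, and the chosen values are assumptions rather than recommended defaults.

# Illustrative only: every field referenced here exists in FP8Args above.
import torch

from nanotron.config.fp8_config import FP8Args

fp8_args = FP8Args(
    resid_dtype=torch.float32,            # residual stream stays in high precision
    accum_dtype=torch.bfloat16,           # accumulation dtype for FP8 matmuls
    smooth_quant=True,                    # enable SmoothQuant in FP8 linears
    smooth_quant_migration_strength=0.5,  # alpha passed to smooth_quant()
)
print(fp8_args.accum_dtype)  # torch.bfloat16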
1 change: 1 addition & 0 deletions src/nanotron/constants.py
@@ -14,3 +14,4 @@
 
 # TODO(xrsrke): remove this shit
 ITERATION_STEP = 1
+CONFIG = None
2 changes: 1 addition & 1 deletion src/nanotron/fp8/constants.py
@@ -48,7 +48,7 @@
 # and the accumulation precision in FP8 recipe
 
 FP8LM_LINEAR_RECIPE = FP8LinearRecipe(
-    accum_dtype=torch.float16,
+    accum_dtype=torch.bfloat16,
     input=FP8TensorRecipe(dtype=DTypes.FP8E4M3, margin=0, interval=16),
     weight=FP8TensorRecipe(dtype=DTypes.FP8E4M3, margin=0, interval=1),
     bias=torch.float16,
83 changes: 83 additions & 0 deletions src/nanotron/fp8/functional.py
@@ -0,0 +1,83 @@
from typing import Optional, Tuple

import torch

from nanotron.fp8.linear import FP8LinearMeta
from nanotron.fp8.recipe import FP8LinearRecipe
from nanotron.fp8.tensor import FP8Tensor


def smooth_quant(input: torch.Tensor, weight: FP8Tensor, alpha: float) -> Tuple[torch.Tensor, FP8Tensor]:
    """
    An implementation of SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models
    https://arxiv.org/abs/2211.10438
    """
    # Compute smoothing factor
    input_s = torch.amax(torch.abs(input), dim=(0, 1), keepdim=True)
    w_s = torch.amax(torch.abs(weight._orig_data_after_set_data), dim=0)

    s = input_s.squeeze().pow(alpha) / w_s.pow(1 - alpha)

    # NOTE: create a smoothed tensor without adding the smoothing operations
    # to the computational graph, and keep the original computational graph
    X_smoothed = input.detach() / s.unsqueeze(dim=0).unsqueeze(dim=0)
    X_smoothed.requires_grad_()

    with torch.no_grad():
        W_smoothed = weight._orig_data_after_set_data * s.unsqueeze(0)
        weight.set_data(W_smoothed)

    return X_smoothed, weight


def linear(
    input: torch.Tensor,
    weight: FP8Tensor,
    bias: Optional[torch.Tensor] = None,
    metadatas: FP8LinearMeta = None,
    recipe: FP8LinearRecipe = None,
    name: Optional[str] = None,
):
    from typing import cast

    from nanotron import constants
    from nanotron.config.fp8_config import FP8Args

    if recipe.smooth_quant is True:
        fp8_config = cast(FP8Args, constants.CONFIG.fp8)
        migration_strength = fp8_config.smooth_quant_migration_strength
        input, weight = smooth_quant(input, weight, alpha=migration_strength)

    assert metadatas is not None, "metadatas must be specified"
    assert recipe is not None, "recipe must be specified"
    assert input.device != torch.device("cpu"), "FP8Linear only supports CUDA tensors"

    # TODO(xrsrke): refactor this out, don't duplicate the code
    from einops import rearrange

    from nanotron.fp8.linear import _FP8Matmul

    seq_len = None
    batch_size = None
    is_input_flat = False
    if input.ndim == 3:
        batch_size = input.shape[0]
        seq_len = input.shape[1]
        is_input_flat = True
        input = rearrange(input, "b n h -> (b n) h")
    elif input.ndim > 3:
        raise ValueError(f"Unsupported input shape: {input.shape}")

    # NOTE: just a phony tensor to make pytorch trigger the backward pass
    # because weight and bias's requires_grad are set to False,
    # so that we can compute the gradients using the fp8 kernels by ourselves
    phony = torch.empty(0, device=input.device, requires_grad=True)
    output = torch.zeros(input.shape[0], weight.shape[0], device="cuda", dtype=recipe.accum_dtype)
    output, _ = _FP8Matmul.apply(input, weight, output, phony, metadatas, recipe, name)

    # TODO(xrsrke): add support for adding bias in fp8
    # TODO(xrsrke): support return an fp8 tensor as output
    # since we will quantize it back to FP8 anyway in the next linear
    output = rearrange(output, "(b n) h -> b n h", n=seq_len, b=batch_size) if is_input_flat is True else output
    output = output if bias is None else output + bias
    return output
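The smoothing step above relies on the SmoothQuant identity: dividing activation channels by s while multiplying the matching weight columns by s leaves the linear output unchanged, but shifts quantization difficulty from activations to weights. A self-contained sketch of that identity in plain PyTorch (assumed shapes, not part of the commit):

# Demonstrates (X / s) @ (W * s).T == X @ W.T for the SmoothQuant scaling factor s.
import torch

torch.manual_seed(0)
alpha = 0.5
X = torch.randn(4, 16, 64)          # [batch, seq, hidden]
W = torch.randn(128, 64)            # [out_features, in_features]

input_s = X.abs().amax(dim=(0, 1))  # per-channel activation max, shape [64]
w_s = W.abs().amax(dim=0)           # per-input-channel weight max, shape [64]
s = input_s.pow(alpha) / w_s.pow(1 - alpha)

X_smoothed = X / s                  # activations become easier to quantize
W_smoothed = W * s                  # each input channel (column) of W absorbs s

out_ref = X @ W.T
out_smooth = X_smoothed @ W_smoothed.T
print(torch.allclose(out_ref, out_smooth, atol=1e-3))  # True, up to float error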
2 changes: 1 addition & 1 deletion src/nanotron/fp8/kernel.py
@@ -1,6 +1,6 @@
 import torch
 import transformer_engine as te  # noqa
-import transformer_engine_extensions as tex
+import transformer_engine_torch as tex
 
 from nanotron.fp8.meta import FP8Meta
 from nanotron.fp8.tensor import FP8Tensor
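The kernel module now imports the Transformer Engine C++ bindings under the `transformer_engine_torch` name. If the surrounding code ever needed to run against releases that still ship the old `transformer_engine_extensions` module, a guarded import along these lines would cover both names; this is an assumption for illustration, not something the commit adds.

try:
    # module name used here and in newer transformer-engine releases
    import transformer_engine_torch as tex
except ImportError:
    # fall back to the extension name the previous code imported
    import transformer_engine_extensions as tex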
14 changes: 11 additions & 3 deletions src/nanotron/fp8/linear.py
@@ -74,7 +74,7 @@ def forward(self, input: Union[FP8Tensor, torch.Tensor]) -> torch.Tensor:
         return F.linear(
             input=input,
             weight=get_data_from_param(self.weight),
-            bias=get_data_from_param(self.bias),
+            bias=None if self.bias is None else get_data_from_param(self.bias),
             metadatas=self.metadatas,
             recipe=self.recipe,
         )
@@ -104,7 +104,11 @@ def forward(
 
         # dist.monitored_barrier(wait_all_ranks=True)
 
-        fp8_config = cast(FP8Args, constants.CONFIG.fp8)
+        if constants.CONFIG is None:
+            fp8_config = FP8Args()
+        else:
+            fp8_config = cast(FP8Args, constants.CONFIG.fp8)
+
         sync_amax_in_input = fp8_config.sync_amax_in_input
 
         orig_input_shape = input.shape
Expand Down Expand Up @@ -159,7 +163,11 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
from nanotron import constants
from nanotron.config.fp8_config import FP8Args

if constants.CONFIG.fp8 is not None and constants.CONFIG.fp8.is_debugging is True:
if (
constants.CONFIG is not None
and constants.CONFIG.fp8 is not None
and constants.CONFIG.fp8.is_debugging is True
):
pydevd.settrace(suspend=False, trace_only_current_thread=True)

# dist.monitored_barrier(wait_all_ranks=True)
Expand Down
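The `_FP8Matmul` changes above sit inside the autograd function that the new `functional.linear` drives through a zero-element "phony" tensor. As a rough standalone illustration of that pattern (hypothetical class, plain-precision matmul standing in for the FP8 kernels, not the commit's actual implementation):

import torch


class _PhonyMatmul(torch.autograd.Function):
    """Toy stand-in for _FP8Matmul: the weight has requires_grad=False, so a tiny
    requires_grad=True "phony" tensor is threaded through the Function purely to
    make PyTorch call backward(), where gradients are computed manually."""

    @staticmethod
    def forward(ctx, input, weight, phony):
        ctx.save_for_backward(input, weight)
        output = input @ weight.t()
        return output, phony  # returning phony keeps the op attached to the graph

    @staticmethod
    def backward(ctx, grad_output, grad_phony):
        input, weight = ctx.saved_tensors
        grad_weight = grad_output.t() @ input  # computed "by hand"
        print("manual weight grad:", tuple(grad_weight.shape))
        grad_input = grad_output @ weight if ctx.needs_input_grad[0] else None
        return grad_input, None, None


x = torch.randn(4, 8)                       # activations without requires_grad
w = torch.randn(16, 8)                      # FP8-like weight: requires_grad=False
phony = torch.empty(0, requires_grad=True)  # forces autograd to build the graph

out, _ = _PhonyMatmul.apply(x, w, phony)
out.sum().backward()                        # backward() runs despite frozen params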
4 changes: 3 additions & 1 deletion src/nanotron/parallel/parameters.py
@@ -202,6 +202,8 @@ def sanity_check(root_module: nn.Module):
 
 
 def get_data_from_param(p: NanotronParameter):
-    assert p.__class__ == NanotronParameter
+    from nanotron.fp8.parameter import FP8Parameter
+
+    assert p.__class__ in [NanotronParameter, FP8Parameter]
     # NOTE: this return the data that gradients can flow into
     return p.data
18 changes: 6 additions & 12 deletions tests/fp8/test_linear.py
@@ -6,7 +6,6 @@
 from nanotron.fp8.constants import FP8_DTYPES, QTYPE_TO_DTYPE
 from nanotron.fp8.dtypes import DTypes
 from nanotron.fp8.linear import FP8Linear, FP8LinearMeta
-from nanotron.fp8.loss_scaler import LossScaler
 from nanotron.fp8.parameter import FP8Parameter
 from nanotron.fp8.recipe import FP8LinearRecipe
 from nanotron.fp8.tensor import FP8Tensor, convert_tensor_from_fp8
@@ -38,7 +37,7 @@ def test_fp8_linear_parameters():
     assert all(p.requires_grad for p in fp8_linear.parameters()) is True
 
 
-@pytest.mark.skip
+# @pytest.mark.skip
 @pytest.mark.parametrize("n_layers", [1, 2])
 @pytest.mark.parametrize(
     "input",
@@ -49,8 +48,10 @@ def test_fp8_linear_parameters():
         torch.randn(64, 64, 64, device="cuda", dtype=torch.float32),  # [B, N, H]
     ],
 )
-@pytest.mark.parametrize("is_bias", [True, False])
-@pytest.mark.parametrize("accum_qtype", [DTypes.KFLOAT32, DTypes.KFLOAT16])
+# @pytest.mark.parametrize("is_bias", [True, False])
+# @pytest.mark.parametrize("accum_qtype", [DTypes.KFLOAT32, DTypes.KFLOAT16, torch.bfloat16])
+@pytest.mark.parametrize("is_bias", [False])
+@pytest.mark.parametrize("accum_qtype", [torch.bfloat16])
 def test_fp8_linear_forward_pass(n_layers, input, is_bias, accum_qtype):
     HIDDEN_SIZE = 64
 
@@ -98,7 +99,6 @@ def test_fp8_linear_forward_pass(n_layers, input, is_bias, accum_qtype):
 # )
 # @pytest.mark.parametrize("is_bias", [True, False])
 # @pytest.mark.skip
-@pytest.mark.parametrize("with_scaler", [True, False])
 @pytest.mark.parametrize("accum_qtype", [DTypes.KFLOAT32, DTypes.KFLOAT16])
 def test_fp8_linear_backward_pass(n_layers, input, with_scaler, accum_qtype):
     is_bias = False
@@ -115,8 +115,6 @@ def test_fp8_linear_backward_pass(n_layers, input, with_scaler, accum_qtype):
         ]
     )
 
-    loss_scaler = LossScaler()
-
     # trunc_normal_(ref_linear.weight, std=0.02)
     # trunc_normal_(ref_linear.weight, std=math.sqrt(1 / (HIDDEN_SIZE)))
 
@@ -125,11 +123,7 @@ def test_fp8_linear_backward_pass(n_layers, input, with_scaler, accum_qtype):
 
     ref_linear(ref_input).sum().backward()
 
-    if with_scaler is False:
-        fp8_linear(input).sum().backward()
-    else:
-        loss_scaler.scale(fp8_linear(input).sum()).backward()
-        loss_scaler.unscale_(fp8_linear.parameters())
+    fp8_linear(input).sum().backward()
 
     for ref_p, p in zip(ref_linear.parameters(), fp8_linear.parameters()):
         if p.requires_grad is False:
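With the loss scaler removed and the parametrization narrowed, the updated tests can be exercised on a CUDA machine along these lines; the keyword filter is only an assumption for convenience, not part of the commit.

# run from the repository root on a machine with a CUDA device
import pytest

pytest.main(["tests/fp8/test_linear.py", "-k", "forward_pass or backward_pass"])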
