From 481baecd59557e806766807ad1fd1cbfccc22d87 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 19 Aug 2024 15:37:17 +0000 Subject: [PATCH 01/50] Changed version to 1.11.0.dev0 Signed-off-by: Kirthi Shankar Sivamani --- build_tools/VERSION.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index a597619ec0..12790c22b4 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -1.10.0.dev0 +1.11.0.dev0 From 350a4ff1946e2b9562fc5850b3a8ca72ae799916 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Mon, 19 Aug 2024 08:49:08 -0700 Subject: [PATCH 02/50] Remove the commit hash from the release documentation (#1118) Signed-off-by: Przemek Tredak --- docs/conf.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 77751994d8..7a50ce76cf 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -47,7 +47,10 @@ git_sha = git_sha[:7] if len(git_sha) > 7 else git_sha -version = str(te_version + "-" + git_sha) +if "dev" in te_version: + version = str(te_version + "-" + git_sha) +else: + version = str(te_version) release = te_version # hack: version is used for html creation, so put the version picker From ee541e83b7088db1580907450a4f88491c8ccaf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Bastien?= Date: Mon, 19 Aug 2024 12:21:22 -0400 Subject: [PATCH 03/50] Stop using global mesh for custom_partitioning. (#1112) Signed-off-by: Frederic Bastien Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> --- .../jax/cpp_extensions/activation.py | 2 +- .../jax/cpp_extensions/attention.py | 2 +- .../jax/cpp_extensions/normalization.py | 10 +++++----- .../jax/cpp_extensions/quantization.py | 2 +- .../jax/cpp_extensions/transpose.py | 12 ++++++------ transformer_engine/jax/sharding.py | 17 ++++++++--------- 6 files changed, 22 insertions(+), 23 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 56359646b1..47483c67ea 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -467,7 +467,7 @@ def sharded_impl(x, amax, scale, scale_inv): local_x, local_amax = ActLuFp8Primitive.impl( x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum ) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) return local_x, global_updated_amax diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 76ccec363b..0cbf847dcd 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1011,7 +1011,7 @@ def sharded_impl( ) global_dbias = local_dbias if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: - global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) + global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh) return local_dq, local_dk, local_dv, global_dbias return mesh, sharded_impl, out_shardings, arg_shardings diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index caf9272b02..e85f28a06a 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -533,8 +533,8 @@ def sharded_impl(dz, x, mu, rsigma, gamma): local_dx, local_dgamma, local_dbeta = LayerNormBwdPrimitive.impl( dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon ) - global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma) - global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta) + global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh) + global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta, mesh) return local_dx, global_dgamma, global_dbeta return mesh, sharded_impl, out_shardings, arg_shardings @@ -935,7 +935,7 @@ def partition(epsilon, mesh, arg_infos, result_infos): def sharded_impl(dz, x, rsigma, gamma): local_dx, local_dgamma = RmsNormBwdPrimitive.impl(dz, x, rsigma, gamma, epsilon=epsilon) - global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma) + global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh) return local_dx, global_dgamma return mesh, sharded_impl, out_shardings, arg_shardings @@ -1228,7 +1228,7 @@ def sharded_impl(x, gamma, beta, amax, scale, scale_inv): zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, ) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) return local_x, local_mu, local_rsigma, global_updated_amax @@ -1481,7 +1481,7 @@ def sharded_impl(x, gamma, amax, scale, scale_inv): local_x, local_rsigma, local_amax = RmsNormFwdFp8Primitive.impl( x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon ) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) return local_x, local_rsigma, global_updated_amax diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 2c529e71c8..48bf4d969a 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -157,7 +157,7 @@ def sharded_impl(x, amax, scale, scale_inv): local_cx, local_updated_amax = CastFP8Primitive.impl( x, amax, scale, scale_inv, out_dtype=out_dtype ) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh) return local_cx, global_updated_amax diff --git a/transformer_engine/jax/cpp_extensions/transpose.py b/transformer_engine/jax/cpp_extensions/transpose.py index e503792dc0..963d7f09e8 100644 --- a/transformer_engine/jax/cpp_extensions/transpose.py +++ b/transformer_engine/jax/cpp_extensions/transpose.py @@ -390,7 +390,7 @@ def sharded_impl(x, amax, scale, scale_inv): static_axis_boundary=static_axis_boundary, transpose_axis_boundary=transpose_axis_boundary, ) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh) return local_cx, local_cxt, global_updated_amax @@ -646,8 +646,8 @@ def sharded_impl(dz, amax, scale, scale_inv): static_axis_boundary=static_axis_boundary, transpose_axis_boundary=transpose_axis_boundary, ) - global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) + global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) return local_out, local_t_out, global_dbias, global_updated_amax return mesh, sharded_impl, out_shardings, arg_shardings @@ -981,8 +981,8 @@ def sharded_impl(dz, x, amax, scale, scale_inv): act_enum=act_enum, ) ) - global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) + global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) return local_out, local_t_out, global_dbias, global_updated_amax return mesh, sharded_impl, out_shardings, arg_shardings @@ -1225,7 +1225,7 @@ def sharded_impl(dz, x, amax, scale, scale_inv): static_axis_boundary=static_axis_boundary, act_enum=act_enum, ) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) return local_out, local_t_out, global_updated_amax return mesh, sharded_impl, out_shardings, arg_shardings diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index c0b60fe61e..586e1a70c9 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -30,8 +30,7 @@ W_JOINED_AXES = "nvte_w_joined" -def _get_mesh_info(resource: str): - mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh +def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh): assert resource in mesh.axis_names, f"{resource} is not in the axis_names of Mesh {mesh}." return mesh.shape[resource], resource @@ -132,12 +131,12 @@ def get_padded_spec(spec, ndim): return spec + (None,) * (ndim - len(spec)) -def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str): +def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str, mesh: jax.sharding.Mesh): """ A wrapper function to invoke lax.p* operations, like psum. """ if mesh_resource is not None: - _, resource = _get_mesh_info(mesh_resource) + _, resource = _get_mesh_info(mesh_resource, mesh) return ops(x, resource) return x @@ -201,22 +200,22 @@ def global_mesh_resource() -> MeshResource: return _GLOBAL_MESH_RESOURCE -def all_reduce_sum_along_dp_fsdp(x: jnp.array): +def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh): """ All-Reduce (Sum) along DP and FSDP mesh axes. """ - x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource) - return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource) + x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource, mesh) + return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh) -def all_reduce_max_along_all_axes_except_PP(x: jnp.array): +def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh): """ All-Reduce (Max) along all mesh axes. """ all_axes = get_all_mesh_axes() for axis in all_axes: if axis != global_mesh_resource().pp_resource: - x = lax_paral_op(x, jax.lax.pmax, axis) + x = lax_paral_op(x, jax.lax.pmax, axis, mesh) return x From 5d5fe819b8d848b1cddcb4c10eba453356c901f4 Mon Sep 17 00:00:00 2001 From: hXl3s Date: Tue, 20 Aug 2024 19:01:37 +0200 Subject: [PATCH 04/50] feat(pytorch): Allow TransformerLayer and MultiheadAttention to accept sequence length parameters (#1066) * Added ability for seqlen for transformer and mha layer Signed-off-by: Lukasz Pierscieniewski * Documentation for new parameters Signed-off-by: Lukasz Pierscieniewski * Add tests for THD layout, assert for THD layout with KV-Cache Signed-off-by: Lukasz Pierscieniewski * Fixed tests Signed-off-by: Lukasz Pierscieniewski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Move THD logic in shape calculation, add missing optional in params Signed-off-by: Lukasz Pierscieniewski * Skip the THD test on GPUs older than Ampere Signed-off-by: Przemek Tredak --------- Signed-off-by: Lukasz Pierscieniewski Signed-off-by: Przemek Tredak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani Co-authored-by: Przemek Tredak --- tests/pytorch/test_numerics.py | 47 ++++++++++++++++++- transformer_engine/pytorch/attention.py | 44 +++++++++++++---- .../pytorch/module/layernorm_mlp.py | 3 +- transformer_engine/pytorch/transformer.py | 20 ++++++++ 4 files changed, 102 insertions(+), 12 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a219f24674..a2023f539a 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -34,11 +34,13 @@ from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace +from transformer_engine.pytorch.utils import get_device_compute_capability import transformer_engine_torch as tex # Only run FP8 tests on H100. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +sm_80plus = get_device_compute_capability() >= (8, 0) seed = 1234 torch.manual_seed(seed) @@ -1548,8 +1550,29 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): attn_input_format="bshd", ) - for (n1, p1), (n2, p2) in zip(block_bshd.named_parameters(), block_sbhd.named_parameters()): - assert torch.all(torch.eq(p1, p2)), f"{n1}, {n2} not identical" + torch.manual_seed(0) + block_thd = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0, + attention_dropout=0, + kv_channels=config.embed, + params_dtype=dtype, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + device="cuda", + attn_input_format="thd", + self_attn_mask_type="padding_causal", + ) + + for (n1, p1), (n2, p2), (n3, p3) in zip( + block_bshd.named_parameters(), block_sbhd.named_parameters(), block_thd.named_parameters() + ): + assert torch.all(torch.eq(p1, p2) & torch.eq(p1, p3)), f"{n1}, {n2} and {n3} not identical" x_sbhd = torch.randn( (config.seq_len, bs, config.hidden_size), @@ -1559,6 +1582,8 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): ) x_bshd = x_sbhd.transpose(0, 1).contiguous() + x_thd = x_bshd.reshape(bs * config.seq_len, config.hidden_size).contiguous() + x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.seq_len # To make sure forward is also identical (just in case some module decides # to act fancy) @@ -1576,6 +1601,24 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): y_sbhd.transpose(0, 1).contiguous(), ) + # THD is not supported in float32 and on GPUs older than Ampere, skip the test here + if dtype != torch.float32 and sm_80plus: + # To make sure forward is also identical (just in case some module decides + # to act fancy) + torch.manual_seed(0) + y_thd = block_thd( + x_thd, + cu_seqlens_q=x_thd_cumsum, + cu_seqlens_kv=x_thd_cumsum, + max_seqlen_q=config.seq_len, + max_seqlen_kv=config.seq_len, + ) + + torch.testing.assert_close( + y_bshd, + y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(), + ) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 904dbbde01..71bc15fdad 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -7048,6 +7048,10 @@ def forward( core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, fast_zero_fill: bool = True, ) -> Tuple[Union[torch.Tensor, None], ...]: """ @@ -7113,6 +7117,18 @@ def forward( ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads]. It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j)) to the attention score of query i and key j. + cu_seqlens_q: Optional[torch.Tensor], default = `None` + Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`, + with shape [batch_size + 1] and dtype torch.int32. + cu_seqlens_kv: Optional[torch.Tensor], default = `None` + Cumulative sum of sequence lengths (without offset) in a batch for `key_layer` + and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. + max_seqlen_q: Optional[int], default = `None` + Maximum sequence length in `query_layer`. + Calculated from `cu_seqlens_q` if not provided. + max_seqlen_kv: Optional[int], default = `None` + Maximum sequence length in `key_layer` and `value_layer`. + Calculated from `cu_seqlens_kv` if not provided. fast_zero_fill: bool, default = `True` Whether to set output tensors to 0 or not before use. """ @@ -7139,6 +7155,9 @@ def forward( # ================================================= if inference_params and self.layer_number is not None: + assert ( + self.qkv_format != "thd" + ), "qkv_format == thd is not supported for an inference with KV-cache!" if self.layer_number not in inference_params.key_value_memory_dict: inf_max_seq_len = inference_params.max_sequence_length inf_max_batch_size = inference_params.max_batch_size @@ -7221,13 +7240,18 @@ def forward( dim=split_dim, ) - # query: -> [sq, b, np, hn] - # key, value: -> [sq, b, ng, hn] - query_layer, key_layer, value_layer = ( - x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head) - for x in (query_layer, key_layer, value_layer) - ) - + if self.qkv_format == "thd": + query_layer, key_layer, value_layer = ( + x.reshape(x.size(0), -1, self.hidden_size_per_attention_head) + for x in (query_layer, key_layer, value_layer) + ) + else: + # query: -> [sq, b, np, hn] + # key, value: -> [sq, b, ng, hn] + query_layer, key_layer, value_layer = ( + x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head) + for x in (query_layer, key_layer, value_layer) + ) elif self.attention_type == "cross": # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)] mixed_kv_layer = self.key_value( @@ -7341,8 +7365,10 @@ def forward( key_layer, value_layer, qkv_format=self.qkv_format, - cu_seqlens_q=None, - cu_seqlens_kv=None, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, attention_mask=attention_mask, attn_mask_type=attn_mask_type, window_size=window_size, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index be6df21322..dc9bef645f 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -13,6 +13,7 @@ from .base import ( get_workspace, + _ub_communicators, get_ub, TransformerEngineBaseModule, _2X_ACC_FPROP, @@ -1297,7 +1298,7 @@ def __init__( self.gemm_gelu_fusion = ( bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) and self.activation == "gelu" - and not get_ub("fc1_fprop").is_atomic_gemm() + and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm())) ) if tp_group is None: diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index f026da23ef..4cbee3d628 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -529,6 +529,10 @@ def forward( core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, fast_zero_fill: bool = True, ) -> torch.Tensor: """ @@ -604,6 +608,18 @@ def forward( ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads]. It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j)) to the attention score of query i and key j. + cu_seqlens_q: Optional[torch.Tensor], default = `None` + Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`, + with shape [batch_size + 1] and dtype torch.int32. + cu_seqlens_kv: Optional[torch.Tensor], default = `None` + Cumulative sum of sequence lengths (without offset) in a batch for `key_layer` + and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. + max_seqlen_q: Optional[int], default = `None` + Maximum sequence length in `query_layer`. + Calculated from `cu_seqlens_q` if not provided. + max_seqlen_kv: Optional[int], default = `None` + Maximum sequence length in `key_layer` and `value_layer`. + Calculated from `cu_seqlens_kv` if not provided. fast_zero_fill: bool, default = `True` Whether to set output tensors to 0 or not before use. inference_params: InferenceParams, default = None @@ -664,6 +680,10 @@ def forward( core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, fast_zero_fill=fast_zero_fill, ) From 8e3561bf8fc27fab6dbbdfe3daf67e7421fbe7f6 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 20 Aug 2024 18:05:48 -0700 Subject: [PATCH 05/50] Update FP8 scale-inverse in kernels with FP8 output (#1083) * Perform scale-inv update in cast-transpose kernels Signed-off-by: Tim Moon * Perform scale-inv update in cast and activation kernels Signed-off-by: Tim Moon * Perform sclae-inv update in LayerNorm and RMSNorm kernels Signed-off-by: Tim Moon * Perform scale-inv update after FP8 GEMMs Signed-off-by: Tim Moon * Fuse casts and scale-inv updates in linear module Signed-off-by: Tim Moon * Fuse casts and scale-inv updates in layernorm-linear module Signed-off-by: Tim Moon * Simplify kernel to update FP8 scale-inv Signed-off-by: Tim Moon * Fix typos Signed-off-by: Tim Moon * Debug amax update in layernorm kernels Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Debug test failures Signed-off-by: Tim Moon * Debug ONNX export Use quantization scaling factor in ONNX quantize op. Signed-off-by: Tim Moon * Review suggestion from @ptrendx Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Debug mismatched dtypes Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/cpp/operator/test_cast_transpose.cu | 2 + .../cpp/operator/test_cast_transpose_dbias.cu | 2 + .../test_cast_transpose_dbias_dgelu.cu | 2 + tests/cpp/operator/test_layernorm.cu | 2 + .../cpp/operator/test_multi_cast_transpose.cu | 4 + tests/cpp/operator/test_rmsnorm.cu | 2 + transformer_engine/common/CMakeLists.txt | 1 + .../common/activation/activation_template.h | 9 +- transformer_engine/common/common.cu | 32 +++ transformer_engine/common/common.h | 7 + .../common/gemm/cublaslt_gemm.cu | 5 + transformer_engine/common/layer_norm/ln.h | 3 + .../common/layer_norm/ln_api.cpp | 1 + .../common/layer_norm/ln_fwd_kernels.cuh | 32 ++- .../common/rmsnorm/rmsnorm_api.cpp | 1 + .../common/rmsnorm/rmsnorm_fwd_kernels.cuh | 32 ++- .../common/transpose/cast_transpose.cu | 28 +- .../common/transpose/cast_transpose_fusion.cu | 54 ++-- .../common/transpose/multi_cast_transpose.cu | 11 +- .../common/transpose/rtc/cast_transpose.cu | 7 +- .../transpose/rtc/cast_transpose_fusion.cu | 14 +- transformer_engine/common/util/cast.cu | 5 +- .../common/util/vectorized_pointwise.h | 95 ++++--- transformer_engine/common/utils.cuh | 5 + .../pytorch/cpp_extensions/_common.py | 87 ++++++ .../pytorch/cpp_extensions/activation.py | 257 ++++++++++-------- .../pytorch/cpp_extensions/cast.py | 90 ++++-- .../pytorch/cpp_extensions/normalization.py | 135 ++++++--- .../pytorch/cpp_extensions/transpose.py | 123 ++++++--- transformer_engine/pytorch/float8_tensor.py | 28 +- transformer_engine/pytorch/module/_common.py | 9 + .../pytorch/module/layernorm_linear.py | 37 ++- transformer_engine/pytorch/module/linear.py | 30 +- .../pytorch/te_onnx_extensions.py | 52 ++-- 34 files changed, 824 insertions(+), 380 deletions(-) create mode 100644 transformer_engine/common/common.cu create mode 100644 transformer_engine/pytorch/cpp_extensions/_common.py diff --git a/tests/cpp/operator/test_cast_transpose.cu b/tests/cpp/operator/test_cast_transpose.cu index 8c168c76f4..39a6614179 100644 --- a/tests/cpp/operator/test_cast_transpose.cu +++ b/tests/cpp/operator/test_cast_transpose.cu @@ -69,6 +69,8 @@ void performTest(const size_t N, const size_t H) { if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output_c.scale(); + compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); diff --git a/tests/cpp/operator/test_cast_transpose_dbias.cu b/tests/cpp/operator/test_cast_transpose_dbias.cu index a2c8594730..651508c871 100644 --- a/tests/cpp/operator/test_cast_transpose_dbias.cu +++ b/tests/cpp/operator/test_cast_transpose_dbias.cu @@ -116,6 +116,8 @@ void performTest(const size_t N, const size_t H) { if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output_c.scale(); + compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); diff --git a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu index 5920f23f38..38ac955bc9 100644 --- a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu +++ b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu @@ -132,6 +132,8 @@ void performTest(const size_t N, const size_t H) { if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output_c.scale(); + compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); diff --git a/tests/cpp/operator/test_layernorm.cu b/tests/cpp/operator/test_layernorm.cu index 07a8a8884c..cdd8e7846c 100644 --- a/tests/cpp/operator/test_layernorm.cu +++ b/tests/cpp/operator/test_layernorm.cu @@ -230,6 +230,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); if (isFp8Type(otype)) { compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / z.scale(); + compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); diff --git a/tests/cpp/operator/test_multi_cast_transpose.cu b/tests/cpp/operator/test_multi_cast_transpose.cu index 85fc3a573a..e7fb183217 100644 --- a/tests/cpp/operator/test_multi_cast_transpose.cu +++ b/tests/cpp/operator/test_multi_cast_transpose.cu @@ -139,6 +139,10 @@ void performTest() { output_c_list[tensor_id].amax(), ref_amax_list[tensor_id], atol_amax, rtol_amax); + compareResults("scale_inv", + output_c_list[tensor_id].scale_inv(), + 1.f / output_c_list[tensor_id].scale(), + atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); compareResults("output_c", diff --git a/tests/cpp/operator/test_rmsnorm.cu b/tests/cpp/operator/test_rmsnorm.cu index e4e34bac8a..0ec3a877e5 100644 --- a/tests/cpp/operator/test_rmsnorm.cu +++ b/tests/cpp/operator/test_rmsnorm.cu @@ -187,6 +187,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); if (isFp8Type(otype)) { compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / z.scale(); + compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 7fab75dca0..a4497751f4 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -44,6 +44,7 @@ set(transformer_engine_SOURCES) list(APPEND transformer_engine_SOURCES pycudnn.cpp transformer_engine.cpp + common.cu transpose/cast_transpose.cu transpose/transpose.cu transpose/cast_transpose_fusion.cu diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index fc93705dff..6184e235bd 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -27,7 +27,8 @@ void act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) { reinterpret_cast(input.data.dptr), reinterpret_cast(output->data.dptr), reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), tot_elts, {}, + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), tot_elts, {}, stream);); // NOLINT(*) ); // NOLINT(*) } @@ -50,7 +51,8 @@ void dact_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream reinterpret_cast(input.data.dptr), reinterpret_cast(output->data.dptr), reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), tot_elts, {}, + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), tot_elts, {}, stream);); // NOLINT(*) ); // NOLINT(*) } @@ -74,7 +76,8 @@ void gated_act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) { reinterpret_cast(input.data.dptr), reinterpret_cast(output->data.dptr), reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), output->data.shape[0], + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), output->data.shape[0], output->data.shape[1], {}, stream);); // NOLINT(*) ); // NOLINT(*) diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu new file mode 100644 index 0000000000..4e95fc24de --- /dev/null +++ b/transformer_engine/common/common.cu @@ -0,0 +1,32 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include "./common.h" +#include "./utils.cuh" + +namespace transformer_engine { + +namespace { + +__global__ void __launch_bounds__(1) + update_tensor_scale_inv_kernel(const float* __restrict__ scale_ptr, + float* __restrict__ scale_inv_ptr) { + const float scale = scale_ptr == nullptr ? 1 : *scale_ptr; + reciprocal(scale_inv_ptr, scale); +} + +} // namespace + +void update_tensor_scale_inv(Tensor* t, cudaStream_t stream) { + if (t->scale_inv.dptr != nullptr) { + update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>( + reinterpret_cast(t->scale.dptr), reinterpret_cast(t->scale_inv.dptr)); + } +} + +} // namespace transformer_engine diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 42b529f388..7e72e1b031 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -262,6 +262,13 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt bool is_fp8_dtype(const DType t); +/*! \brief Update a tensor's FP8 scale-inverse + * + * The FP8 scale-inverse (dequantization scaling factor) is updated + * with the reciprocal of the FP8 scale (quantization scaling factor). + */ +void update_tensor_scale_inv(Tensor *t, cudaStream_t stream); + #define NVTE_API_CALL(api_name) \ transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name); diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index c9b57752e2..8667b64e65 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -269,6 +269,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, workspace, /* workspace */ workspaceSize, stream)); /* stream */ + // Update FP8 scale-inv in output tensor + if (is_fp8_dtype(outputD->data.dtype)) { + update_tensor_scale_inv(outputD, stream); + } + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc)); diff --git a/transformer_engine/common/layer_norm/ln.h b/transformer_engine/common/layer_norm/ln.h index 45839ed75b..13543a10aa 100644 --- a/transformer_engine/common/layer_norm/ln.h +++ b/transformer_engine/common/layer_norm/ln.h @@ -89,6 +89,9 @@ struct FwdParams : public ParamsBase { // AMax output void *amax; + // Inverse of scaling factor + void *scale_inv; + // Whether to compute scale and amax bool fp8_out; }; diff --git a/transformer_engine/common/layer_norm/ln_api.cpp b/transformer_engine/common/layer_norm/ln_api.cpp index 115422e94e..8a40450e59 100644 --- a/transformer_engine/common/layer_norm/ln_api.cpp +++ b/transformer_engine/common/layer_norm/ln_api.cpp @@ -196,6 +196,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size params.epsilon = epsilon; params.amax = z->amax.dptr; params.scale = z->scale.dptr; + params.scale_inv = z->scale_inv.dptr; params.fp8_out = fp8_out; params.zero_centered_gamma = zero_centered_gamma; diff --git a/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh b/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh index 9fe4c16373..bd3741d1d1 100644 --- a/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh +++ b/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh @@ -132,10 +132,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( } } if (params.fp8_out) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0 && threadIdx.y == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); + // Reduce amax over block + if (params.amax != nullptr) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { + reciprocal(reinterpret_cast(params.scale_inv), scale); } } } @@ -291,10 +299,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne // Finalize fp8 factors if (params.fp8_out) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); + // Reduce amax over block + if (params.amax != nullptr) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { + reciprocal(reinterpret_cast(params.scale_inv), scale); } } } diff --git a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp index e9a6ff483d..9b143b2f85 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp @@ -159,6 +159,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens params.epsilon = epsilon; params.amax = z->amax.dptr; params.scale = z->scale.dptr; + params.scale_inv = z->scale_inv.dptr; params.fp8_out = fp8_out; params.zero_centered_gamma = zero_centered_gamma; diff --git a/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh b/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh index a1cfc2293c..c435ae3744 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh +++ b/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh @@ -125,10 +125,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke } } if (params.fp8_out) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0 && threadIdx.y == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); + // Reduce amax over block + if (params.amax != nullptr) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { + reciprocal(reinterpret_cast(params.scale_inv), scale); } } } @@ -267,10 +275,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ // Finalize fp8 factors if (params.fp8_out) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); + // Reduce amax over block + if (params.amax != nullptr) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { + reciprocal(reinterpret_cast(params.scale_inv), scale); } } } diff --git a/transformer_engine/common/transpose/cast_transpose.cu b/transformer_engine/common/transpose/cast_transpose.cu index 6cbd4daade..dd45d0a668 100644 --- a/transformer_engine/common/transpose/cast_transpose.cu +++ b/transformer_engine/common/transpose/cast_transpose.cu @@ -101,14 +101,11 @@ struct KernelConfig { }; template -__global__ void __launch_bounds__(block_size) - cast_transpose_general_kernel(const IType *__restrict__ const input, - const CType *__restrict__ const noop, - OType *__restrict__ const output_c, - OType *__restrict__ const output_t, - const CType *__restrict__ const scale_ptr, - CType *__restrict__ const amax_ptr, const size_t row_length, - const size_t num_rows) { +__global__ void __launch_bounds__(block_size) cast_transpose_general_kernel( + const IType *__restrict__ const input, const CType *__restrict__ const noop, + OType *__restrict__ const output_c, OType *__restrict__ const output_t, + const CType *__restrict__ const scale_ptr, CType *__restrict__ const amax_ptr, + CType *__restrict__ const scale_inv_ptr, const size_t row_length, const size_t num_rows) { if (noop != nullptr && noop[0] == 1.0f) return; // Vectorized load/store sizes @@ -207,9 +204,15 @@ __global__ void __launch_bounds__(block_size) if (amax_ptr != nullptr) { amax = reduce_max(amax, tidy); if (threadIdx.x == 0) { + static_assert(std::is_same::value); atomicMaxFloat(amax_ptr, amax); } } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv_ptr != nullptr) { + reciprocal(scale_inv_ptr, scale); + } } } // namespace @@ -255,6 +258,8 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output "Cast and transposed outputs need to share amax tensor."); NVTE_CHECK(cast_output.scale.dptr == transposed_output.scale.dptr, "Cast and transposed outputs need to share scale tensor."); + NVTE_CHECK(cast_output.scale_inv.dptr == transposed_output.scale_inv.dptr, + "Cast and transposed outputs need to share scale-inverse tensor."); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.data.dtype, InputType, @@ -324,7 +329,9 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output static_cast(cast_output.data.dptr), static_cast(transposed_output.data.dptr), static_cast(cast_output.scale.dptr), - static_cast(cast_output.amax.dptr), row_length, num_rows); + static_cast(cast_output.amax.dptr), + static_cast(cast_output.scale_inv.dptr), row_length, + num_rows); } else { // Statically-compiled general kernel constexpr size_t load_size = 4; constexpr size_t store_size = 4; @@ -339,7 +346,8 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output static_cast(cast_output.data.dptr), static_cast(transposed_output.data.dptr), static_cast(cast_output.scale.dptr), - static_cast(cast_output.amax.dptr), row_length, num_rows); + static_cast(cast_output.amax.dptr), + static_cast(cast_output.scale_inv.dptr), row_length, num_rows); }); // NOLINT(*) ); // NOLINT(*) } diff --git a/transformer_engine/common/transpose/cast_transpose_fusion.cu b/transformer_engine/common/transpose/cast_transpose_fusion.cu index d96757990f..a8361d57ea 100644 --- a/transformer_engine/common/transpose/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/cast_transpose_fusion.cu @@ -433,15 +433,19 @@ __global__ void __launch_bounds__(cast_transpose_num_threads) } } - /* warp tile amax reduce*/ - amax = reduce_max(amax, warp_id); - - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - if (param.amax != nullptr) { + // Reduce amax over block + if (param.amax != nullptr) { + amax = reduce_max(amax, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); atomicMaxFloat(param.amax, amax); } } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && param.scale_inv != nullptr) { + reciprocal(param.scale_inv, scale); + } } static const char *ActTypeToString[] = { @@ -870,17 +874,18 @@ __global__ void __launch_bounds__(cast_transpose_num_threads) __syncthreads(); } - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); - - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - if (amax != nullptr) { + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); atomicMaxFloat(amax, max); } - if (scale_inv != nullptr) { - reciprocal(scale_inv, scale); - } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, scale); } } @@ -1079,17 +1084,18 @@ __global__ void __launch_bounds__(cast_transpose_num_threads) __syncthreads(); } - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); - - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - if (amax != nullptr) { + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); atomicMaxFloat(amax, max); } - if (scale_inv != nullptr) { - reciprocal(scale_inv, scale); - } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, scale); } } diff --git a/transformer_engine/common/transpose/multi_cast_transpose.cu b/transformer_engine/common/transpose/multi_cast_transpose.cu index 8e6e90a7bf..4026016519 100644 --- a/transformer_engine/common/transpose/multi_cast_transpose.cu +++ b/transformer_engine/common/transpose/multi_cast_transpose.cu @@ -36,6 +36,8 @@ struct MultiCastTransposeArgs { void* scale_list[kMaxTensorsPerKernel]; // (output) AMAX's of input tensors void* amax_list[kMaxTensorsPerKernel]; + // (output) Inverse of scaling factor for output tensors + void* scale_inv_list[kMaxTensorsPerKernel]; // Input matrix heights int num_rows_list[kMaxTensorsPerKernel]; // Input matrix widths @@ -82,7 +84,8 @@ __global__ void __launch_bounds__(threads_per_block) OType* output_t = reinterpret_cast(args.output_t_list[tensor_id]); const CType* scale_ptr = reinterpret_cast(args.scale_list[tensor_id]); const CType scale = scale_ptr == nullptr ? 1 : *scale_ptr; - CType* amax = reinterpret_cast(args.amax_list[tensor_id]); + CType* amax_ptr = reinterpret_cast(args.amax_list[tensor_id]); + CType* scale_inv_ptr = reinterpret_cast(args.scale_inv_list[tensor_id]); const int num_rows = args.num_rows_list[tensor_id]; const int row_length = args.row_length_list[tensor_id]; @@ -183,7 +186,10 @@ __global__ void __launch_bounds__(threads_per_block) local_amax = reduce_max(local_amax, tidy); if (tid == 0) { static_assert(std::is_same::value); - if (amax != nullptr) atomicMaxFloat(amax, local_amax); + if (amax_ptr != nullptr) atomicMaxFloat(amax_ptr, local_amax); + } + if (tile_id == 0 && tid == 0 && scale_inv_ptr != nullptr) { + reciprocal(scale_inv_ptr, scale); } } @@ -285,6 +291,7 @@ void multi_cast_transpose(const std::vector input_list, kernel_args.output_t_list[pos] = transposed_output_list[tensor_id]->data.dptr; kernel_args.scale_list[pos] = cast_output_list[tensor_id]->scale.dptr; kernel_args.amax_list[pos] = cast_output_list[tensor_id]->amax.dptr; + kernel_args.scale_inv_list[pos] = cast_output_list[tensor_id]->scale_inv.dptr; kernel_args.num_rows_list[pos] = num_rows; kernel_args.row_length_list[pos] = row_length; kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles; diff --git a/transformer_engine/common/transpose/rtc/cast_transpose.cu b/transformer_engine/common/transpose/rtc/cast_transpose.cu index 6ea8326147..07244a42e9 100644 --- a/transformer_engine/common/transpose/rtc/cast_transpose.cu +++ b/transformer_engine/common/transpose/rtc/cast_transpose.cu @@ -25,7 +25,7 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel( const IType* __restrict__ const input, const CType* __restrict__ const noop, OType* __restrict__ const output_c, OType* __restrict__ const output_t, const CType* __restrict__ const scale_ptr, CType* __restrict__ const amax_ptr, - const size_t row_length, const size_t num_rows) { + CType* __restrict__ const scale_inv_ptr, const size_t row_length, const size_t num_rows) { if (noop != nullptr && noop[0] == 1.0f) return; // Vectorized load/store sizes @@ -121,4 +121,9 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel( atomicMaxFloat(amax_ptr, amax); } } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv_ptr != nullptr) { + reciprocal(scale_inv_ptr, scale); + } } diff --git a/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu b/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu index c005be98ef..4ba1cb4c69 100644 --- a/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu @@ -229,12 +229,16 @@ __global__ void __launch_bounds__(BLOCK_SIZE) } } - // warp tile amax reduce - const CType max_block = reduce_max(amax, warp_id); - - if (threadIdx.x == 0) { - if (param.amax != nullptr) { + // Reduce amax over block + if (param.amax != nullptr) { + const CType max_block = reduce_max(amax, warp_id); + if (threadIdx.x == 0) { atomicMaxFloat(param.amax, max_block); } } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && param.scale_inv != nullptr) { + reciprocal(param.scale_inv, scale); + } } diff --git a/transformer_engine/common/util/cast.cu b/transformer_engine/common/util/cast.cu index 8dd2b98ebf..dd03afd21b 100644 --- a/transformer_engine/common/util/cast.cu +++ b/transformer_engine/common/util/cast.cu @@ -46,7 +46,8 @@ void fp8_quantize(const Tensor &input, Tensor *output, cudaStream_t stream) { reinterpret_cast(input.data.dptr), reinterpret_cast(output->data.dptr), reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), N, {}, + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream);); // NOLINT(*) ); // NOLINT(*) } @@ -68,7 +69,7 @@ void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { p.scale_inv = reinterpret_cast(input.scale_inv.dptr); VectorizedUnaryKernelLauncher( reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), nullptr, nullptr, N, p, + reinterpret_cast(output->data.dptr), nullptr, nullptr, nullptr, N, p, stream);); // NOLINT(*) ); // NOLINT(*) } diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 63ad1857cf..8653bf45a4 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -168,12 +168,12 @@ template __launch_bounds__(unary_kernel_threads) __global__ void unary_kernel(const InputType *input, OutputType *output, const ComputeType *scale, - ComputeType *amax, Param p, const size_t N, + ComputeType *amax, ComputeType *scale_inv, Param p, const size_t N, const size_t num_aligned_elements) { VectorizedLoader loader(input, N); VectorizedStorer storer(output, N); ComputeType max = 0; - ComputeType s = 0; + ComputeType s = 1; if constexpr (is_fp8::value) { if (scale != nullptr) s = *scale; } @@ -199,12 +199,18 @@ __launch_bounds__(unary_kernel_threads) __global__ storer.store(tid, N); } if constexpr (is_fp8::value) { - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } + } - if (threadIdx.x == 0 && amax != nullptr) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); } } } @@ -214,13 +220,13 @@ template __launch_bounds__(unary_kernel_threads) __global__ void unary_grad_kernel(const InputTypeGrad *grad, const InputType *input, OutputType *output, - const ComputeType *scale, ComputeType *amax, Param p, const size_t N, - const size_t num_aligned_elements) { + const ComputeType *scale, ComputeType *amax, ComputeType *scale_inv, + Param p, const size_t N, const size_t num_aligned_elements) { VectorizedLoader loader(input, N); VectorizedLoader grad_loader(grad, N); VectorizedStorer storer(output, N); ComputeType max = 0; - ComputeType s = 0; + ComputeType s = 1; if constexpr (is_fp8::value) { if (scale != nullptr) s = *scale; } @@ -248,12 +254,18 @@ __launch_bounds__(unary_kernel_threads) __global__ storer.store(tid, N); } if constexpr (is_fp8::value) { - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } + } - if (threadIdx.x == 0 && amax != nullptr) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); } } } @@ -311,7 +323,7 @@ Alignment CheckAlignment(const size_t lead_dim, const int nvec, const T... ptrs) template void VectorizedUnaryKernelLauncher(const InputType *input, OutputType *output, const fp32 *scale, - fp32 *amax, const size_t N, const Param params, + fp32 *amax, fp32 *scale_inv, const size_t N, const Param params, cudaStream_t stream) { if (N != 0) { auto align = CheckAlignment(N, nvec, input, output); @@ -325,16 +337,16 @@ void VectorizedUnaryKernelLauncher(const InputType *input, OutputType *output, c switch (align) { case Alignment::SAME_ALIGNED: unary_kernel<<>>( - input, output, scale, amax, params, N, num_aligned_elements); + input, output, scale, amax, scale_inv, params, N, num_aligned_elements); break; case Alignment::SAME_UNALIGNED: unary_kernel<<>>( - input, output, scale, amax, params, N, num_aligned_elements); + input, output, scale, amax, scale_inv, params, N, num_aligned_elements); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize - unary_kernel<1, true, fp32, Param, OP> - <<>>(input, output, scale, amax, params, N, N); + unary_kernel<1, true, fp32, Param, OP><<>>( + input, output, scale, amax, scale_inv, params, N, N); break; } } @@ -345,7 +357,8 @@ template void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputType *input, OutputType *output, const fp32 *scale, fp32 *amax, - const size_t N, const Param params, cudaStream_t stream) { + fp32 *scale_inv, const size_t N, const Param params, + cudaStream_t stream) { if (N != 0) { auto align = CheckAlignment(N, nvec, input, grad, output); @@ -358,16 +371,16 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputTyp switch (align) { case Alignment::SAME_ALIGNED: unary_grad_kernel<<>>( - grad, input, output, scale, amax, params, N, num_aligned_elements); + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements); break; case Alignment::SAME_UNALIGNED: unary_grad_kernel<<>>( - grad, input, output, scale, amax, params, N, num_aligned_elements); + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize - unary_grad_kernel<1, true, fp32, Param, OP> - <<>>(grad, input, output, scale, amax, params, N, N); + unary_grad_kernel<1, true, fp32, Param, OP><<>>( + grad, input, output, scale, amax, scale_inv, params, N, N); break; } } @@ -379,8 +392,8 @@ template __launch_bounds__(unary_kernel_threads) __global__ void gated_act_kernel(const InputType *input, OutputType *output, const ComputeType *scale, - ComputeType *amax, const size_t m, const size_t n, const Param p, - const size_t num_aligned_elements) { + ComputeType *amax, ComputeType *scale_inv, const size_t m, const size_t n, + const Param p, const size_t num_aligned_elements) { const size_t M = num_aligned_elements * m; for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { const size_t id_x = tid % num_aligned_elements; @@ -389,7 +402,7 @@ __launch_bounds__(unary_kernel_threads) __global__ VectorizedLoader loader1(input + id_y * n * 2 + n, n); VectorizedStorer storer(output + id_y * n, n); ComputeType max = 0; - ComputeType s = 0; + ComputeType s = 1; if constexpr (is_fp8::value) { if (scale != nullptr) s = *scale; } @@ -412,12 +425,18 @@ __launch_bounds__(unary_kernel_threads) __global__ storer.store(id_x, n); if constexpr (is_fp8::value) { - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } + } - if (threadIdx.x == 0 && amax != nullptr) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); } } } @@ -427,8 +446,8 @@ template void GatedActivationKernelLauncher(const InputType *input, OutputType *output, const fp32 *scale, - fp32 *amax, const size_t m, const size_t n, const Param &p, - cudaStream_t stream) { + fp32 *amax, fp32 *scale_inv, const size_t m, const size_t n, + const Param &p, cudaStream_t stream) { if (m != 0 && n != 0) { size_t num_aligned_elements = get_num_aligned_elements(input, n, nvec, sizeof(InputType)); constexpr size_t threads = unary_kernel_threads; @@ -439,18 +458,18 @@ void GatedActivationKernelLauncher(const InputType *input, OutputType *output, c switch (auto align = CheckAlignment(n, nvec, input, input + n, output)) { case Alignment::SAME_ALIGNED: gated_act_kernel - <<>>(input, output, scale, amax, m, n, p, + <<>>(input, output, scale, amax, scale_inv, m, n, p, num_aligned_elements); break; case Alignment::SAME_UNALIGNED: gated_act_kernel - <<>>(input, output, scale, amax, m, n, p, + <<>>(input, output, scale, amax, scale_inv, m, n, p, num_aligned_elements); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize gated_act_kernel<1, true, ComputeType, Param, Activation> - <<>>(input, output, scale, amax, m, n, p, n); + <<>>(input, output, scale, amax, scale_inv, m, n, p, n); break; } } diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index bcfc0c608d..6703ce728c 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -852,6 +852,11 @@ __device__ __forceinline__ void reciprocal(T *value_inv, const T value) { *value_inv = 1 / value; } +template <> +__device__ __forceinline__ void reciprocal(float *value_inv, const float value) { + *value_inv = __frcp_rn(value); +} + } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_ diff --git a/transformer_engine/pytorch/cpp_extensions/_common.py b/transformer_engine/pytorch/cpp_extensions/_common.py new file mode 100644 index 0000000000..6ab7d95138 --- /dev/null +++ b/transformer_engine/pytorch/cpp_extensions/_common.py @@ -0,0 +1,87 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Helper functions for C++ extensions""" +import functools +from typing import Dict, Optional, Tuple, Union + +import torch + +import transformer_engine_torch as tex + + +@functools.lru_cache(maxsize=None) +def empty_tensor() -> torch.Tensor: + """Get tensor with no entries and no data""" + return torch.Tensor() + + +def canonicalize_fp8_scales( + *, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + fp8_meta: Optional[tex.FP8TensorMeta] = None, + fp8_meta_index: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None] = None, + allow_multiple_offsets: bool = True, +) -> Tuple[Dict[str, torch.Tensor], Dict[str, int]]: + """Canonicalize FP8 scaling factors (scale, amax, scale-inverse) + + If a scaling factor is not provided, try to access it within the + FP8 meta tensors. Returns dict with tensors and dict with tensor + offsets. + + """ + + # Default: use provided scales with no offsets + scale_offset = 0 + amax_offset = 0 + scale_inv_offset = 0 + + # Get scales from FP8 meta tensors if needed + if (fp8_meta is not None) and any(arg is None for arg in (scale, amax, scale_inv)): + if fp8_meta_index is None: + raise ValueError("Provided `fp8_meta` without corresponding `fp8_meta_index`") + fp8_meta_index = int(fp8_meta_index) + if scale is None: + scale = fp8_meta.scale + scale_offset = fp8_meta_index + if amax is None: + amax = fp8_meta.amax_history + amax_offset = fp8_meta_index + if scale_inv is None: + scale_inv = fp8_meta.scale_inv + scale_inv_offset = fp8_meta_index + + # Construct empty tensors if needed + if scale is None: + scale = empty_tensor() + scale_offset = 0 + if amax is None: + amax = empty_tensor() + amax_offset = 0 + if scale_inv is None: + scale_inv = empty_tensor() + scale_inv_offset = 0 + + # Force offsets to be the same if needed + if not allow_multiple_offsets and not scale_offset == amax_offset == scale_inv_offset: + if scale_offset != 0: + scale = scale[scale_offset] + scale_offset = 0 + if amax_offset != 0: + amax = amax[0][amax_offset] + amax_offset = 0 + if scale_inv_offset != 0: + scale_inv = scale_inv[scale_inv_offset] + scale_inv_offset = 0 + + # Pack tensors and offsets into dicts + tensors = dict(scale=scale, amax=amax, scale_inv=scale_inv) + offsets = dict( + scale_offset=scale_offset, + amax_offset=amax_offset, + scale_inv_offset=scale_inv_offset, + ) + return tensors, offsets diff --git a/transformer_engine/pytorch/cpp_extensions/activation.py b/transformer_engine/pytorch/cpp_extensions/activation.py index 767fe25291..f204982aa0 100644 --- a/transformer_engine/pytorch/cpp_extensions/activation.py +++ b/transformer_engine/pytorch/cpp_extensions/activation.py @@ -3,192 +3,235 @@ # See LICENSE for license information. """Python interface for activation extensions""" -from typing import Union +from typing import Optional, Union + import torch -import transformer_engine_torch as tex +import transformer_engine_torch as tex +from ._common import canonicalize_fp8_scales __all__ = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"] def gelu( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """GeLU with FP8 output""" - empty_tensor = torch.Tensor() - if fp8_meta_tensor is not None: - scale = fp8_meta_tensor.scale - amax_history = fp8_meta_tensor.amax_history - scale_inv = fp8_meta_tensor.scale_inv - else: - scale = empty_tensor - amax_history = empty_tensor - scale_inv = empty_tensor + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel return torch.ops.tex_ts.gelu_ts( inp, - scale, - amax_history, - scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, ) def relu( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ReLU with FP8 output""" - empty_tensor = torch.Tensor() - if fp8_meta_tensor is not None: - scale = fp8_meta_tensor.scale - amax_history = fp8_meta_tensor.amax_history - scale_inv = fp8_meta_tensor.scale_inv - else: - scale = empty_tensor - amax_history = empty_tensor - scale_inv = empty_tensor + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel return torch.ops.tex_ts.relu_ts( inp, - scale, - amax_history, - scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, ) def geglu( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """GeGLU with FP8 output""" - empty_tensor = torch.Tensor() - if fp8_meta_tensor is not None: - scale = fp8_meta_tensor.scale - amax_history = fp8_meta_tensor.amax_history - scale_inv = fp8_meta_tensor.scale_inv - else: - scale = empty_tensor - amax_history = empty_tensor - scale_inv = empty_tensor + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel return torch.ops.tex_ts.geglu_ts( inp, - scale, - amax_history, - scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, ) def reglu( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ReGLU with FP8 output""" - empty_tensor = torch.Tensor() - if fp8_meta_tensor is not None: - scale = fp8_meta_tensor.scale - amax_history = fp8_meta_tensor.amax_history - scale_inv = fp8_meta_tensor.scale_inv - else: - scale = empty_tensor - amax_history = empty_tensor - scale_inv = empty_tensor + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel return torch.ops.tex_ts.reglu_ts( inp, - scale, - amax_history, - scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, ) def swiglu( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """SwiGLU with FP8 output""" - empty_tensor = torch.Tensor() - if fp8_meta_tensor is not None: - scale = fp8_meta_tensor.scale - amax_history = fp8_meta_tensor.amax_history - scale_inv = fp8_meta_tensor.scale_inv - else: - scale = empty_tensor - amax_history = empty_tensor - scale_inv = empty_tensor + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel return torch.ops.tex_ts.swiglu_ts( inp, - scale, - amax_history, - scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, ) def qgelu( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """QuickGELU with FP8 output""" - empty_tensor = torch.Tensor() - if fp8_meta_tensor is not None: - scale = fp8_meta_tensor.scale - amax_history = fp8_meta_tensor.amax_history - scale_inv = fp8_meta_tensor.scale_inv - else: - scale = empty_tensor - amax_history = empty_tensor - scale_inv = empty_tensor + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel return torch.ops.tex_ts.qgelu_ts( inp, - scale, - amax_history, - scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, ) def srelu( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ReLU with FP8 output""" - empty_tensor = torch.Tensor() - if fp8_meta_tensor is not None: - scale = fp8_meta_tensor.scale - amax_history = fp8_meta_tensor.amax_history - scale_inv = fp8_meta_tensor.scale_inv - else: - scale = empty_tensor - amax_history = empty_tensor - scale_inv = empty_tensor + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel return torch.ops.tex_ts.srelu_ts( inp, - scale, - amax_history, - scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, ) diff --git a/transformer_engine/pytorch/cpp_extensions/cast.py b/transformer_engine/pytorch/cpp_extensions/cast.py index 2856d4727b..0c78a65a6c 100644 --- a/transformer_engine/pytorch/cpp_extensions/cast.py +++ b/transformer_engine/pytorch/cpp_extensions/cast.py @@ -4,57 +4,91 @@ """Python interface for cast extensions""" from typing import Optional, Union + import torch -import transformer_engine_torch as tex +import transformer_engine_torch as tex +from ._common import canonicalize_fp8_scales, empty_tensor __all__ = ["cast_to_fp8", "cast_from_fp8"] def cast_to_fp8( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, out: Optional[torch.Tensor] = None, -) -> Optional[torch.Tensor]: + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, +) -> torch.Tensor: """Cast input to FP8""" - if out is not None: - if inp.nelement() > 0: - torch.ops.tex_ts.cast_to_fp8_noalloc_ts( - inp, - fp8_meta_tensor.scale, - out, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor, - otype, - ) - return None - - return torch.ops.tex_ts.cast_to_fp8_ts( - inp, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor, - otype, + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, ) + # Launch FP8 cast kernel + if inp.nelement() == 0: + if out is None: + out = torch.empty_like(inp, dtype=torch.uint8) + elif out is None: + out = torch.ops.tex_ts.cast_to_fp8_ts( + inp, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], + otype, + ) + else: + torch.ops.tex_ts.cast_to_fp8_noalloc_ts( + inp, + fp8_scales["scale"], + out, + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], + otype, + ) + return out + def cast_from_fp8( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], itype: tex.DType, otype: tex.DType, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Cast input from FP8""" + + # Get scaling factors from FP8 meta tensors if needed + scale_inv_offset = 0 + if (fp8_meta_tensor is not None) and (scale_inv is None): + if fp8_tensor is None: + raise ValueError("Provided `fp8_meta_tensor` without corresponding `fp8_tensor`") + scale_inv = fp8_meta_tensor.scale_inv + scale_inv_offset = int(fp8_tensor) + + # Construct empty tensors if needed + if scale_inv is None: + scale_inv = empty_tensor() + scale_inv_offset = 0 + + # Launch FP8 cast kernel return torch.ops.tex_ts.cast_from_fp8_ts( inp, - fp8_meta_tensor.scale_inv, - fp8_tensor, + scale_inv, + scale_inv_offset, itype, otype, ) diff --git a/transformer_engine/pytorch/cpp_extensions/normalization.py b/transformer_engine/pytorch/cpp_extensions/normalization.py index dd90bb0b66..50fd6b7709 100644 --- a/transformer_engine/pytorch/cpp_extensions/normalization.py +++ b/transformer_engine/pytorch/cpp_extensions/normalization.py @@ -4,8 +4,11 @@ """Python interface for normalization extensions""" from typing import Optional, Tuple, Union + import torch + import transformer_engine_torch as tex +from ._common import canonicalize_fp8_scales __all__ = [ @@ -23,46 +26,55 @@ def layernorm_fwd_fp8( weight: torch.Tensor, bias: torch.Tensor, eps: float, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, sm_margin: int, zero_centered_gamma: bool, ln_out: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """LayerNorm with FP8 output""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + ) + + # Launch kernel if ln_out is not None: return tex.layernorm_fwd_fp8_noalloc( inp, weight, bias, eps, - fp8_meta_tensor.scale, + fp8_scales["scale"], ln_out, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["amax"], + fp8_scales["scale_inv"], otype, sm_margin, zero_centered_gamma, - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) - return tex.layernorm_fwd_fp8( inp, weight, bias, eps, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], otype, sm_margin, zero_centered_gamma, - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) @@ -71,26 +83,41 @@ def layernorm_fwd_fp8_inf( weight: torch.Tensor, bias: torch.Tensor, eps: float, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, sm_margin: int, zero_centered_gamma, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """LayerNorm with FP8 output. This version of layernorm_fwd_fp8 is specialized for inference, and returns only the normalized output. """ + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel ret = torch.ops.tex_ts.layernorm_fwd_fp8_inf_ts( inp, weight, bias, eps, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, sm_margin, zero_centered_gamma, @@ -121,44 +148,53 @@ def rmsnorm_fwd_fp8( inp: torch.Tensor, weight: torch.Tensor, eps: float, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, sm_margin: int, zero_centered_gamma: bool, rmsnorm_out: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """RMSNorm with FP8 output""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + ) + + # Launch kernel if rmsnorm_out is not None: return tex.rmsnorm_fwd_fp8_noalloc( inp, weight, eps, - fp8_meta_tensor.scale, + fp8_scales["scale"], rmsnorm_out, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["amax"], + fp8_scales["scale_inv"], otype, sm_margin, zero_centered_gamma, - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) - return tex.rmsnorm_fwd_fp8( inp, weight, eps, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], otype, sm_margin, zero_centered_gamma, - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) @@ -166,25 +202,40 @@ def rmsnorm_fwd_fp8_inf( inp: torch.Tensor, weight: torch.Tensor, eps: float, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, sm_margin: int, zero_centered_gamma, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """RMSNorm with FP8 output. This version of rmsnorm_fwd_fp8 is specialized for inference, and returns only the normalized output. """ + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel ret = torch.ops.tex_ts.rmsnorm_fwd_fp8_inf_ts( inp, weight, eps, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, sm_margin, zero_centered_gamma, diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index d96b743b9e..37a1b59da2 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -4,9 +4,12 @@ """Python interface for transpose extensions""" from typing import List, Optional, Tuple, Union + import torch + import transformer_engine_torch as tex from ..constants import TE_DType +from ._common import canonicalize_fp8_scales, empty_tensor __all__ = [ @@ -20,83 +23,115 @@ def fp8_cast_transpose_fused( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, cast_out: Optional[torch.Tensor] = None, transpose_out: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, noop_flag: Optional[torch.Tensor] = None, -) -> Union[Tuple[torch.Tensor, torch.Tensor], None]: +) -> Tuple[torch.Tensor, torch.Tensor]: """Cast + Transpose with FP8 output""" - return_outputs = False + # Allocate outputs if needed if transpose_out is None: transpose_out = torch.empty(inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8) - return_outputs = True if cast_out is None: cast_out = torch.empty_like(inp, dtype=torch.uint8) - return_outputs = True + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + ) + + # Construct no-op flag if needed if noop_flag is None: - noop_flag = torch.Tensor() + noop_flag = empty_tensor() + # Launch kernel if needed if inp.nelement() > 0: tex.fused_cast_transpose_noop( inp, noop_flag, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], cast_out, transpose_out, otype, - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) - if return_outputs: - return cast_out, transpose_out - return None + return cast_out, transpose_out def fp8_cast_transpose_bgrad_fused( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Cast + Transpose + BGRAD with FP8 output""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + ) + + # Launch kernel return tex.fused_cast_transpose_bgrad( inp, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], otype, - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) def fp8_transpose_bgrad_fused( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, grad_bias_type: torch.dtype, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Transpose + BGRAD with FP8 output""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + ) + + # Launch kernel return tex.fused_fp8_transpose_bgrad( inp, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], otype, TE_DType[grad_bias_type], - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) @@ -106,18 +141,30 @@ def fp8_cast_transpose_bgrad_dgelu_fused( fp8_meta_tensor: tex.FP8TensorMeta, fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Cast + Transpose + BGRAD + DGELU with FP8 output""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + ) + + # Launch kernel return tex.fused_cast_transpose_bgrad_dgelu( grad_output, gelu_input, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], otype, - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index b7f87ad397..d531979868 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -117,13 +117,6 @@ def forward( scale_inv: Optional[torch.Tensor] = None, ) -> Float8Tensor: - # Manually compute scale-inverse if needed - if scale is not None and scale_inv is None: - if isinstance(scale, torch.Tensor): - scale_inv = scale.reciprocal() - else: - scale_inv = 1 / scale - # Extract data from FP8 meta tensors if provided if fp8_meta is not None: fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( @@ -138,9 +131,6 @@ def forward( scale = fp8_meta[fp8_meta_key].scale[fp8_meta_index] if amax is None: amax = fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] - if scale_inv is None: - scale_inv = fp8_meta[fp8_meta_key].scale_inv[fp8_meta_index] - scale_inv = scale_inv.detach().view(1).clone() # Check input tensor tensor = tensor.contiguous().cuda().detach() @@ -163,8 +153,9 @@ def forward( # Check scale-inverse if scale_inv is None: - scale_inv = scale.reciprocal() - scale_inv = scale_inv.to(device=tensor.device, dtype=torch.float32) + scale_inv = torch.empty([1], dtype=torch.float32, device=tensor.device) + else: + scale_inv = scale_inv.to(device=tensor.device, dtype=torch.float32) # Check amax if amax is None: @@ -737,19 +728,9 @@ def cast_transpose_( self._fp8_dtype, cast_out=data, transpose_out=transpose, + scale_inv=self._scale_inv, noop_flag=noop_flag, ) - scale = fp8_meta.scale[fp8_meta_index : fp8_meta_index + 1] - scale_inv = self._scale_inv - if noop_flag is None: - torch.reciprocal(scale, out=scale_inv) - else: - torch.where( - noop_flag.bool(), - scale_inv, - scale.reciprocal(), - out=scale_inv, - ) self._transpose_invalid = False @torch.no_grad() @@ -853,7 +834,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): fp8_meta_index = dst._fp8_meta_index scale = dst._fp8_meta[fp8_meta_key].scale[fp8_meta_index] amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] - dst._scale_inv.copy_(scale.detach().reciprocal()) # Cast to FP8 if not dst._data.is_contiguous(): diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 281e3fe104..23a06e318f 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -52,6 +52,9 @@ def _apply_normalization( fwd_ln_sm_margin: int, zero_centered_gamma: bool, is_grad_enabled: bool, + fp8_scale: Optional[torch.Tensor] = None, + fp8_amax: Optional[torch.Tensor] = None, + fp8_scale_inv: Optional[torch.Tensor] = None, ): normalization_func = _get_normalization_func(normalization, fp8_out, is_grad_enabled, True) @@ -70,6 +73,9 @@ def _apply_normalization( fp8_dtype_forward, fwd_ln_sm_margin, zero_centered_gamma, + scale=fp8_scale, + amax=fp8_amax, + scale_inv=fp8_scale_inv, **output_kwarg, ) else: @@ -82,6 +88,9 @@ def _apply_normalization( fp8_dtype_forward, fwd_ln_sm_margin, zero_centered_gamma, + scale=fp8_scale, + amax=fp8_amax, + scale_inv=fp8_scale_inv, ), None, None, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 10560cdad6..d6045d8e77 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -46,6 +46,7 @@ from ..graph import is_graph_capturing from ._common import _apply_normalization, _noop_cat from ..float8_tensor import Float8Tensor +from ..export import is_in_onnx_export_mode __all__ = ["LayerNormLinear"] @@ -126,8 +127,13 @@ def forward( inputmat, dtype=ln_out_dtype, memory_format=torch.contiguous_format ) + # Objects for FP8 cast fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + ln_out_scale_inv = None + if fp8: + ln_out_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device) + # Launch normalization kernel ln_out, mu, rsigma = _apply_normalization( inputmat, ln_out, @@ -140,6 +146,7 @@ def forward( fwd_ln_sm_margin, zero_centered_gamma, is_grad_enabled, + fp8_scale_inv=ln_out_scale_inv, ) # Column Parallel Linear @@ -172,6 +179,7 @@ def forward( tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, out=ln_out_fp8, + scale_inv=ln_out_scale_inv, ) ln_out = torch.empty_like(ln_out_fp8) else: @@ -180,6 +188,7 @@ def forward( fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, + scale_inv=ln_out_scale_inv, ) if ln_out_gathered: rank = torch.distributed.get_rank(tp_group) @@ -199,6 +208,18 @@ def forward( assert isinstance(weight_fp8, Float8Tensor) + # Hack for ONNX export + # Note: ONNX models are represented as a graph of tensor + # operations, so the in-place scale-inv update doesn't fit + # very well. We work around this by making it look like + # the scale-inv tensor is initialized with a copy. + # Note: ONNX export expects FP8 scales can be represented + # with constant ops. However, copying into a buffer + # involves an expand op for array broadcasting. We work + # around this by filling the buffer instead. + if is_in_onnx_export_mode(): + ln_out_scale_inv.fill_(ln_out_scale_inv.item()) + if fp8_meta["recipe"].fp8_mha: out_index, meta_tensor, output_te_dtype, output_dtype = ( tex.FP8FwdTensors.GEMM1_OUTPUT, @@ -219,8 +240,8 @@ def forward( 0, weight_fp8._fp8_dtype, ln_out_total, - fp8_meta["scaling_fwd"].scale_inv, - tex.FP8FwdTensors.GEMM1_INPUT, + ln_out_scale_inv, + 0, fp8_dtype_forward, output_dtype, get_workspace(), @@ -306,7 +327,7 @@ def forward( weight_fp8, weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, ln_out if weight.requires_grad else None, - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, + ln_out_scale_inv, ) ctx.activation_dtype = activation_dtype @@ -377,7 +398,7 @@ def backward( weight_fp8, main_grad, ln_out, - fwd_scale_inverses, + ln_out_scale_inv, ) = ctx.saved_tensors # Gather intermediate/activation tensors if needed @@ -570,8 +591,8 @@ def backward( ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) wgrad, _ = tex.fp8_gemm( ln_out_total_t, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM1_INPUT, + ln_out_scale_inv, + 0, fp8_dtype_forward, ( grad_output_t._data @@ -596,8 +617,8 @@ def backward( else: ln_out_total_c = torch.ops.tex_ts.cast_from_fp8_ts( ln_out_total, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM1_INPUT, + ln_out_scale_inv, + 0, fp8_dtype_forward, TE_DType[ctx.activation_dtype], ) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 68d333262d..175e5ab5cf 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -48,6 +48,7 @@ from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ..float8_tensor import Float8Tensor +from ..export import is_in_onnx_export_mode __all__ = ["Linear"] @@ -103,10 +104,12 @@ def forward( inputmat = cast_if_needed(inputmat, activation_dtype) inputmat_t = None inputmat_no_fp8 = inputmat + inputmat_scale_inv = None if fp8: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) if isinstance(inputmat, Float8Tensor): + inputmat_scale_inv = inputmat._scale_inv if ( not fp8_meta["recipe"].override_linear_precision.wgrad and is_grad_enabled @@ -116,6 +119,7 @@ def forward( # FP8 input for forward, FP8 input transpose for backward wgrad inputmat_t = inputmat.transpose_2d() else: + inputmat_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device) if ( not fp8_meta["recipe"].override_linear_precision.wgrad and is_grad_enabled @@ -128,6 +132,7 @@ def forward( fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, + scale_inv=inputmat_scale_inv, ) else: # FP8 input for forward @@ -136,8 +141,21 @@ def forward( fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, + scale_inv=inputmat_scale_inv, ) + # Hack for ONNX export + # Note: ONNX models are represented as a graph of tensor + # operations, so the in-place scale-inv update doesn't fit + # very well. We work around this by making it look like + # the scale-inv tensor is initialized with a copy. + # Note: ONNX export expects FP8 scales can be represented + # with constant ops. However, copying into a buffer + # involves an expand op for array broadcasting. We work + # around this by filling the buffer instead. + if is_in_onnx_export_mode(): + inputmat_scale_inv.fill_(inputmat_scale_inv.item()) + # Column Parallel Linear if parallel_mode == "column" and sequence_parallel: inputmat_total, _ = gather_along_first_dim(inputmat, tp_group) @@ -206,8 +224,8 @@ def forward( if isinstance(inputmat_total, Float8Tensor) else inputmat_total ), - fp8_meta["scaling_fwd"].scale_inv, - tex.FP8FwdTensors.GEMM1_INPUT, + inputmat_scale_inv, + 0, fp8_dtype_forward, proj_out_pttype, get_workspace(), @@ -312,10 +330,10 @@ def forward( ctx.save_for_backward( saved_inputmat, saved_inputmat_t, + inputmat_scale_inv, weight, weight_fp8, weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, ) ctx.activation_dtype = activation_dtype @@ -364,10 +382,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ( inputmat, inputmat_t, + inputmat_scale_inv, weight, weight_fp8, main_grad, - fwd_scale_inverses, ) = ctx.saved_tensors # Gather intermediate/activation tensors if needed @@ -520,8 +538,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if isinstance(inputmat_t_total, Float8Tensor) else inputmat_t_total ), - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM1_INPUT, + inputmat_scale_inv, + 0, fp8_dtype_forward, grad_output_t, ctx.fp8_meta["scaling_bwd"].scale_inv, diff --git a/transformer_engine/pytorch/te_onnx_extensions.py b/transformer_engine/pytorch/te_onnx_extensions.py index 05c1a5a0f5..0fa9401163 100755 --- a/transformer_engine/pytorch/te_onnx_extensions.py +++ b/transformer_engine/pytorch/te_onnx_extensions.py @@ -74,7 +74,7 @@ def is_dtype_bf16(t): return t.type().scalarType() == "BFloat16" -def quantize(g, inputs, scale_inv, fp8_tensor): +def quantize(g, inputs, scale, fp8_tensor): """Helper Function for Quantization""" output_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs) @@ -83,7 +83,7 @@ def quantize(g, inputs, scale_inv, fp8_tensor): if not is_dtype_fp32(inputs): inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT) - scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor])) + scale = g.op("Constant", value_t=torch.tensor(1 / scale[fp8_tensor])) q_op = g.op(make_op_name("TRT_FP8QuantizeLinear"), inputs, scale).setType( inputs.type().with_dtype(torch.uint8).with_sizes(output_shape) ) @@ -124,18 +124,18 @@ def compute_in_fp32(g, inp, subgraph, *args, **kwargs): return sg_out -@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") +@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") def onnx_cast_to_fp8(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for cast_to_fp8""" # pylint: disable=unused-argument - return quantize(g, inputs, scale_inv, fp8_tensor) + return quantize(g, inputs, scale, fp8_tensor) -@symbolic_helper.parse_args("v", "v", "v", "v", "fs", "i", "i") +@symbolic_helper.parse_args("v", "fs", "v", "v", "v", "i", "i") def onnx_cast_to_fp8_noalloc(g, inputs, scale, output, amax, scale_inv, fp8_tensor, otype): """ONNX graph for cast_to_fp8_noalloc""" # pylint: disable=unused-argument - return quantize(g, inputs, scale_inv, fp8_tensor) + return quantize(g, inputs, scale, fp8_tensor) @symbolic_helper.parse_args("v", "fs", "i", "i", "i") @@ -145,25 +145,25 @@ def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype): return dequantize(g, inputs, scale_inv, fp8_tensor, otype) -@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") +@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_gelu""" # pylint: disable=unused-argument # TE computes GELU using float32 precision so wrap the GELU subgraph with # conversion to/from float32. gelu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.gelu, "tanh") - if scale_inv: - gelu = quantize(g, gelu, scale_inv, fp8_tensor) + if scale: + gelu = quantize(g, gelu, scale, fp8_tensor) return gelu -@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") +@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") def onnx_fp8_relu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_relu""" # pylint: disable=unused-argument relu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.relu) - if scale_inv: - relu = quantize(g, relu, scale_inv, fp8_tensor) + if scale: + relu = quantize(g, relu, scale, fp8_tensor) return relu @@ -178,13 +178,13 @@ def onnx_swiglu(g: jit_utils.GraphContext, inp, dim): return g.op("Mul", g.op("Sigmoid", first), second) -@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") +@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") def onnx_fp8_swiglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_swiglu""" # pylint: disable=unused-argument swiglu = compute_in_fp32(g, inputs, onnx_swiglu, 1) - if scale_inv: - swiglu = quantize(g, swiglu, scale_inv, fp8_tensor) + if scale: + swiglu = quantize(g, swiglu, scale, fp8_tensor) return swiglu @@ -199,13 +199,13 @@ def onnx_reglu(g: jit_utils.GraphContext, inp, dim): return g.op("Mul", g.op("Relu", first), second) -@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") +@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") def onnx_fp8_reglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_reglu""" # pylint: disable=unused-argument reglu = compute_in_fp32(g, inputs, onnx_reglu, 1) - if scale_inv: - reglu = quantize(g, reglu, scale_inv, fp8_tensor) + if scale: + reglu = quantize(g, reglu, scale, fp8_tensor) return reglu @@ -221,13 +221,13 @@ def onnx_geglu(g: jit_utils.GraphContext, inp, dim): return g.op("Mul", first_gelu, second) -@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") +@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") def onnx_fp8_geglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_geglu""" # pylint: disable=unused-argument geglu = compute_in_fp32(g, inputs, onnx_geglu, 1) - if scale_inv: - geglu = quantize(g, geglu, scale_inv, fp8_tensor) + if scale: + geglu = quantize(g, geglu, scale, fp8_tensor) return geglu @@ -245,7 +245,7 @@ def onnx_fp8_geglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): "v", "fs", "i", - "fs", + "v", "v", "i", "v", @@ -330,7 +330,7 @@ def _ones_like(g, inp, dtype): return one -@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i", "i", "b") +@symbolic_helper.parse_args("v", "v", "v", "f", "fs", "v", "v", "i", "i", "i", "b") def onnx_layernorm_fwd_fp8( g, inputs, @@ -355,7 +355,7 @@ def onnx_layernorm_fwd_fp8( bias = g.op("Cast", bias, to_i=inp_dtype) ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_gamma) - fp8_ln = quantize(g, ln, scale_inv, fp8_tensor) + fp8_ln = quantize(g, ln, scale, fp8_tensor) return fp8_ln @@ -391,7 +391,7 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_ga return ln -@symbolic_helper.parse_args("v", "v", "f", "v", "v", "fs", "i", "i", "i", "b") +@symbolic_helper.parse_args("v", "v", "f", "fs", "v", "v", "i", "i", "i", "b") def onnx_rmsnorm_fwd_fp8( g, inputs, @@ -413,7 +413,7 @@ def onnx_rmsnorm_fwd_fp8( weight = g.op("Cast", weight, to_i=inp_dtype) ln = onnx_rmsnorm_fwd(g, inputs, weight, eps, sm_margin, zero_centered_gamma) - fp8_ln = quantize(g, ln, scale_inv, fp8_tensor) + fp8_ln = quantize(g, ln, scale, fp8_tensor) return fp8_ln From 525de6cc65d0a06ef13ba6cf67680796456475d3 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 20 Aug 2024 20:14:41 -0700 Subject: [PATCH 06/50] Update cudnn-frontend to v1.6.1 (#1108) * update FE to 1.6 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update to 1.6.1-rc for testing Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update to fe 1.6.1 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- .../common/fused_attn/fused_attn_fp8.cu | 30 +++++++++++++++---- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 98ca4e1941..2533f5e5c1 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 98ca4e1941fe3263f128f74f10063a3ea35c7019 +Subproject commit 2533f5e5c1877fd76266133c1479ef1643ce3a8b diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index bda3f5beba..fb7765e1a4 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1835,8 +1835,14 @@ void fused_attn_fp8_fwd_impl_v1( generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride); - amax_o->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - amax_s->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + amax_o->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + amax_s->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); Stats->set_output(true) .set_data_type(fe::DataType_t::FLOAT) @@ -2182,10 +2188,22 @@ void fused_attn_fp8_bwd_impl_v1( dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride(q_stride); dK->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(k_stride); dV->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(v_stride); - amax_dQ->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - amax_dK->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - amax_dV->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - amax_dP->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + amax_dQ->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + amax_dK->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + amax_dV->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + amax_dP->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); dO->set_data_type(bwd_tensor_type); dQ->set_data_type(bwd_tensor_type); From 26c8fcc92dfd43ea19e308a97062211aab711066 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> Date: Tue, 20 Aug 2024 20:17:03 -0700 Subject: [PATCH 07/50] Add FP8 support to CP implementation with KV P2P (#1114) * add window_size to AttnFuncWithCP Signed-off-by: Xiaowei Ren * add seq_offsets_qkvo for cudnn thd Signed-off-by: Xiaowei Ren * add seq_offsets_qkvo to AttnFuncWithCP Signed-off-by: Xiaowei Ren * fix seq_offsets calculation of cudnn thd Signed-off-by: Xiaowei Ren * remove a thd assert Signed-off-by: Xiaowei Ren * fix bias for thd test Signed-off-by: Xiaowei Ren * add thd test for cudnn FA with CP Signed-off-by: Xiaowei Ren * skip GQA/MQA test for cuDNN THD Signed-off-by: Xiaowei Ren * make sure seq_offsets are computed with qkv_group of hd_hd_hd while CP>1 Signed-off-by: Xiaowei Ren * fix seq_offsets inputs Signed-off-by: Xiaowei Ren * remove two comments Signed-off-by: Xiaowei Ren * fix attn mask type for cudnn thd with cp Signed-off-by: Xiaowei Ren * fix attn_mask_type check Signed-off-by: Xiaowei Ren * fix attn_mask_type for cudnn fa with thd Signed-off-by: Xiaowei Ren * fix a typo Signed-off-by: Xiaowei Ren * fix out dout in bwd Signed-off-by: Xiaowei Ren * assert cudnn+thd does not support attn bias Signed-off-by: Xiaowei Ren * check if attn_mask_type has padding Signed-off-by: Xiaowei Ren * minor change Signed-off-by: Xiaowei Ren * change cp test batch size to 2 Signed-off-by: Xiaowei Ren * fix code format Signed-off-by: Xiaowei Ren * fix two assert info Signed-off-by: Xiaowei Ren * fix assert comment Signed-off-by: Xiaowei Ren * fix assert comments Signed-off-by: Xiaowei Ren * minor fix Signed-off-by: Xiaowei Ren * fix assert comments Signed-off-by: Xiaowei Ren * assert swa+CP cannot work with thd format Signed-off-by: Xiaowei Ren * add a new CP function for swa Signed-off-by: Xiaowei Ren * add a missing dgrads Signed-off-by: Xiaowei Ren * minor change Signed-off-by: Xiaowei Ren * add draft fwd function for swa+cp Signed-off-by: Xiaowei Ren * minor change Signed-off-by: Xiaowei Ren * enable flash attention for swa+cp Signed-off-by: Xiaowei Ren * remove an assert of swa+cp Signed-off-by: Xiaowei Ren * call SWAFuncWithCP for swa+cp Signed-off-by: Xiaowei Ren * typo fix Signed-off-by: Xiaowei Ren * use 2hd layout Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change qkv_format check Signed-off-by: Xiaowei Ren * add a code comment Signed-off-by: Xiaowei Ren * tensor shape bug fix Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tensor shape fix Signed-off-by: Xiaowei Ren * add function to compute cu_seqlens of a cp rank Signed-off-by: Xiaowei Ren * add cu_seqlens and cu_seqlens_padded to context parallelism Signed-off-by: Xiaowei Ren * typo fix Signed-off-by: Xiaowei Ren * minor change Signed-off-by: Xiaowei Ren * fix FlashAttention output sequence length Signed-off-by: Xiaowei Ren * fix cu_seqlens_kv_per_step calculation Signed-off-by: Xiaowei Ren * zero dQKV for ending padded tokens Signed-off-by: Xiaowei Ren * zero dQKV tensors of FlashAttention Signed-off-by: Xiaowei Ren * fix softmax_lse correction Signed-off-by: Xiaowei Ren * remove padded tokens of KV to save comounication Signed-off-by: Xiaowei Ren * do not need to zero dkv for FlashAttention any mroe Signed-off-by: Xiaowei Ren * zero out tensors Signed-off-by: Xiaowei Ren * remove redundant code Signed-off-by: Xiaowei Ren * fix CP unit test Signed-off-by: Xiaowei Ren * fix kv shape of cp test with thd format Signed-off-by: Xiaowei Ren * update cp unit test Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add simple code framework Signed-off-by: Xiaowei Ren * try not to have a separate CP function for SWA Signed-off-by: Xiaowei Ren * backup some code change Signed-off-by: Xiaowei Ren * back up code Signed-off-by: Xiaowei Ren * clean up fwd implementation of SWAFuncWithCP Signed-off-by: Xiaowei Ren * remove redundant code Signed-off-by: Xiaowei Ren * code cleaning Signed-off-by: Xiaowei Ren * fix assert info Signed-off-by: Xiaowei Ren * reduce kv chunk concat overheads Signed-off-by: Xiaowei Ren * minor change Signed-off-by: Xiaowei Ren * make AttnFuncWithCP and SWAFuncWithCP have same API Signed-off-by: Xiaowei Ren * add a docstring Signed-off-by: Xiaowei Ren * preliminary implementation of SWAFuncWithCP forward seems working Signed-off-by: Xiaowei Ren * fix output shape of SWAFuncWithCP Signed-off-by: Xiaowei Ren * code refactoring for FlashAttention and add a code placeholder for bwd Signed-off-by: Xiaowei Ren * use gather_along_first_dim Signed-off-by: Xiaowei Ren * finish the preliminary implementation of bwd Signed-off-by: Xiaowei Ren * remove redundant code Signed-off-by: Xiaowei Ren * fix assert condition Signed-off-by: Xiaowei Ren * add draft implementation of SWA+CP with FusedAttention Signed-off-by: Xiaowei Ren * fix attention mask type of swa+cp Signed-off-by: Xiaowei Ren * code cleaning Signed-off-by: Xiaowei Ren * add qkv_layout Signed-off-by: Xiaowei Ren * add missing window_size argument Signed-off-by: Xiaowei Ren * typo fix Signed-off-by: Xiaowei Ren * fix kv shape of swa+cp Signed-off-by: Xiaowei Ren * bug and typo fix Signed-off-by: Xiaowei Ren * fix dout shape Signed-off-by: Xiaowei Ren * add multi stream in fwd of swa+cp Signed-off-by: Xiaowei Ren * save chunk_ids_to_kv_ag in fwd Signed-off-by: Xiaowei Ren * add multi stream in bwd of swa+cp Signed-off-by: Xiaowei Ren * minor fix to cp stream sync Signed-off-by: Xiaowei Ren * rename AttnFuncWithCP Signed-off-by: Xiaowei Ren * check if window size is None Signed-off-by: Xiaowei Ren * fix docstring of AttnFuncWithCP Signed-off-by: Xiaowei Ren * minor fix Signed-off-by: Xiaowei Ren * add env var for users to choose KV ag or KV p2p Signed-off-by: Xiaowei Ren * update cp tests Signed-off-by: Xiaowei Ren * fix window size in cp unit test Signed-off-by: Xiaowei Ren * fix pytest skip messages Signed-off-by: Xiaowei Ren * add cp_comm_type into API Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * code cleaning Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add deterministic konb in cuDNN fused attn backend Signed-off-by: Xiaowei Ren * pass fp8 and fp8_meta to attn_func_with_cp Signed-off-by: Xiaowei Ren * assert only Fused Attn can support FP8+CP Signed-off-by: Xiaowei Ren * remove redundant assert Signed-off-by: Xiaowei Ren * add a fwd draft implementation of FP8 + CP Signed-off-by: Xiaowei Ren * save fp8 and fp8_meta Signed-off-by: Xiaowei Ren * assert sequence length divisible requirements Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove a redundant qkv_layout compute Signed-off-by: Xiaowei Ren * if condition change Signed-off-by: Xiaowei Ren * some typo fix Signed-off-by: Xiaowei Ren * add support table of context parallelism Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * typo and code format fix Signed-off-by: Xiaowei Ren * do not print multiple disabling messages Signed-off-by: Xiaowei Ren * bug fix Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix aux_ctx_tensors of FP8 Signed-off-by: Xiaowei Ren * bug fix Signed-off-by: Xiaowei Ren * fix device in torch.arange and adjust code for the PR of MLA Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * commit code change for FP8+CP Signed-off-by: Xiaowei Ren * commit more code change for FP8+CP Signed-off-by: Xiaowei Ren * commit more fp8 code for FP8+CP Signed-off-by: Xiaowei Ren * bug fixes Signed-off-by: Xiaowei Ren * bug fix Signed-off-by: Xiaowei Ren * cast merged CP results from FP32 to BF16 Signed-off-by: Xiaowei Ren * typo fix Signed-off-by: Xiaowei Ren * minor change Signed-off-by: Xiaowei Ren * fix softmax_lse Signed-off-by: Xiaowei Ren * fix some bugs of FP8 dkv exchange Signed-off-by: Xiaowei Ren * typo fix Signed-off-by: Xiaowei Ren * add FP8 unit test Signed-off-by: Xiaowei Ren * fix typos and clean asserts Signed-off-by: Xiaowei Ren * fix get_p2p_comm_info Signed-off-by: Xiaowei Ren * fix dkv p2p exchange Signed-off-by: Xiaowei Ren * minor fix Signed-off-by: Xiaowei Ren * change FP8 dkv P2P to A2A Signed-off-by: Xiaowei Ren * add FP8+CP unit test Signed-off-by: Xiaowei Ren * typo fix Signed-off-by: Xiaowei Ren * assert amax reduction is needed for FP8+CP Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove duplicated code Signed-off-by: Xiaowei Ren * destroy process group in CP unit test Signed-off-by: Xiaowei Ren * remove interval from fp8_recipe because it has been deprecated Signed-off-by: Xiaowei Ren * try to fix the failed CP test with the latest CI pipeline Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove redundant f before string Signed-off-by: Xiaowei Ren * change META_O_CP Signed-off-by: Xiaowei Ren --------- Signed-off-by: Xiaowei Ren Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xiaowei Ren --- .../fused_attn/run_fused_attn_with_cp.py | 147 ++-- .../fused_attn/test_fused_attn_with_cp.py | 12 +- transformer_engine/pytorch/attention.py | 696 ++++++++++++------ 3 files changed, 592 insertions(+), 263 deletions(-) diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 2433a8a09d..6c775fb127 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -2,15 +2,18 @@ # # See LICENSE for license information. -import os, sys +import os, sys, logging +from contextlib import nullcontext import torch import torch.distributed as dist from transformer_engine.pytorch.attention import DotProductAttention from transformer_engine.pytorch.attention import get_cu_seqlens_on_cp_rank import transformer_engine_torch as tex from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn +from transformer_engine.pytorch.fp8 import fp8_autocast +from transformer_engine.common.recipe import DelayedScaling -dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} +dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} def run_dpa_with_cp( @@ -57,6 +60,9 @@ def run_dpa_with_cp( assert rank in cp_comm_ranks cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") + if dtype == "fp8": + fp8_recipe = DelayedScaling(fp8_dpa=True) + # instantiate core attn module core_attn = DotProductAttention( config.num_heads, @@ -171,18 +177,27 @@ def run_dpa_with_cp( # run core_attn without CP for x in [q, k, v]: x.requires_grad = True - out = core_attn( - q, - k, - v, - core_attention_bias_type=config.attn_bias_type, - core_attention_bias=bias, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], - cu_seqlens_kv_padded=None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1], - ) - out.backward(dout) + + if dtype == "fp8": + fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) + else: + fp8_context = nullcontext() + + with fp8_context: + out = core_attn( + q, + k, + v, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias=bias, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], + cu_seqlens_kv_padded=( + None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] + ), + ) + out.backward(dout) # run core_attn wit CP q_, k_, v_, dout_, *rest = [ @@ -226,31 +241,34 @@ def run_dpa_with_cp( core_attn.set_context_parallel_group( cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type ) - out_ = core_attn( - q_, - k_, - v_, - core_attention_bias_type=config.attn_bias_type, - core_attention_bias=bias_, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], - cu_seqlens_kv_padded=None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1], - ) - out_.backward(dout_) + + if dtype == "fp8": + core_attn.reset_fp8_meta_tensors() + fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) + else: + fp8_context = nullcontext() + + with fp8_context: + out_ = core_attn( + q_, + k_, + v_, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias=bias_, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], + cu_seqlens_kv_padded=( + None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] + ), + ) + out_.backward(dout_) for x in [out_, q_.grad, k_.grad, v_.grad]: assert torch.all(~torch.isnan(x)) assert torch.all(~torch.isinf(x)) # compare results with and without CP - tols = dict(atol=5e-3, rtol=5e-3) - if dtype == "bf16": - if config.num_heads == config.num_gqa_groups: - tols = dict(atol=2.5e-2, rtol=2.5e-2) - else: - tols = dict(atol=3.5e-2, rtol=3.5e-2) - if qkv_format == "bshd" or qkv_format == "sbhd": dq, dk, dv, out = [ x.view( @@ -309,32 +327,55 @@ def run_dpa_with_cp( else: assert False, f"{qkv_format} is an unsupported qkv_format!" + if dtype == "bf16": + if config.num_heads == config.num_gqa_groups: + tols = dict(atol=2.5e-2, rtol=2.5e-2) + else: + tols = dict(atol=3.5e-2, rtol=3.5e-2) + elif dtype == "fp16": + tols = dict(atol=5e-3, rtol=5e-3) + elif dtype == "fp8": + tols = dict(atol=5e-1, rtol=5e-1) + rmse_tol = 0.1 + else: + assert False, f"{dtype} is an unsupported dtype!" + + def _rmse(a, b): + return torch.sqrt((a - b).square().mean()).item() + + def _error(a, b): + if dtype != "fp8": + torch.testing.assert_close(a, b, **tols) + else: + try: + torch.testing.assert_close(a, b, **tols) + except Exception as e: + logging.debug(e) + + rmse = _rmse(a, b) + rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) + assert ( + rmse < rmse_tol * rmse_range + ), "RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( + rmse, rmse_tol * rmse_range, rmse_tol, rmse_range + ) + if qkv_format == "bshd": - torch.testing.assert_close(out_[:, 0], out[:, 0], **tols) - torch.testing.assert_close(dq_[:, 0], dq[:, 0], **tols) - torch.testing.assert_close(dk_[:, 0], dk[:, 0], **tols) - torch.testing.assert_close(dv_[:, 0], dv[:, 0], **tols) - torch.testing.assert_close(out_[:, 1], out[:, 1], **tols) - torch.testing.assert_close(dq_[:, 1], dq[:, 1], **tols) - torch.testing.assert_close(dk_[:, 1], dk[:, 1], **tols) - torch.testing.assert_close(dv_[:, 1], dv[:, 1], **tols) + for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): + _error(a[:, 0], b[:, 0]) + _error(a[:, 1], b[:, 1]) elif qkv_format == "sbhd": - torch.testing.assert_close(out_[0], out[0], **tols) - torch.testing.assert_close(dq_[0], dq[0], **tols) - torch.testing.assert_close(dk_[0], dk[0], **tols) - torch.testing.assert_close(dv_[0], dv[0], **tols) - torch.testing.assert_close(out_[1], out[1], **tols) - torch.testing.assert_close(dq_[1], dq[1], **tols) - torch.testing.assert_close(dk_[1], dk[1], **tols) - torch.testing.assert_close(dv_[1], dv[1], **tols) + for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): + _error(a[0], b[0]) + _error(a[1], b[1]) elif qkv_format == "thd": - torch.testing.assert_close(out_, out, **tols) - torch.testing.assert_close(dq_, dq, **tols) - torch.testing.assert_close(dk_, dk, **tols) - torch.testing.assert_close(dv_, dv, **tols) + for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): + _error(a, b) else: assert False, f"{qkv_format} is an unsupported qkv_format!" + dist.destroy_process_group() + def main(**kwargs): run_dpa_with_cp(**kwargs) diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 0074d18cec..82875e2791 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -90,7 +90,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") -@pytest.mark.parametrize("dtype", ["bf16", "fp16"]) +@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"]) @pytest.mark.parametrize("model", model_configs_fused_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) @pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"]) @@ -121,8 +121,16 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): ) if config.window_size != (-1, 0) and config.window_size != (-1, -1): pytest.skip( - f"Fused attention does not support sliding window attention + context parallelism yet!" + "Fused attention does not support sliding window attention + context parallelism yet!" + ) + if cp_comm_type == "all_gather" and dtype == "fp8": + pytest.skip( + "CP implementation with KV all-gather does not support FP8 + context parallelism yet!" ) + if dtype == "fp8" and qkv_format == "thd": + pytest.skip("FP8 attention cannot work with THD format yet!") + if dtype == "fp8" and config.attn_bias_type != "no_bias": + pytest.skip("FP8 attention cannot work with bias yet!") subprocess.run( get_bash_arguments( diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 71bc15fdad..8fac4778c8 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -95,6 +95,9 @@ META_DO = tex.FP8BwdTensors.GRAD_INPUT2 META_S = tex.FP8FwdTensors.GEMM3_OUTPUT META_DP = tex.FP8BwdTensors.GRAD_INPUT3 +# repurpose some unused amax history buffers for partial results of CP fwd and bwd +META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT +META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1 # NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) @@ -654,18 +657,6 @@ def get_attention_backend( logger.debug("Disabling FusedAttention as no backend supports the provided input") use_fused_attention = False fused_attention_backend = None - if ( - use_fused_attention - and context_parallel - and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"] - ): - logger.debug( - "Disabling FusedAttention as only sub-backend %s does not support " - "context parallellism", - int(fused_attention_backend), - ) - use_fused_attention = False - fused_attention_backend = None if ( use_fused_attention and window_size is not None @@ -1322,6 +1313,8 @@ def forward( attn_bias, deterministic, use_fused_attention, + fp8, + fp8_meta, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -1407,6 +1400,43 @@ def forward( # synchronize fwd results correction across steps fwd_results_correction_done = torch.cuda.Event() + if fp8: + if use_fused_attention: + fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + fused_attn_qkv_dtype = fp8_dtype_forward + fused_attn_backend = FusedAttnBackend["FP8"] + if fp8_meta["recipe"].fp8_mha: + assert ( + isinstance(q, Float8Tensor) + and isinstance(k, Float8Tensor) + and isinstance(v, Float8Tensor) + ), "q/k/v must be Float8Tensors for FP8 MHA!" + fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv + q_fp8, k_fp8, v_fp8 = q, k, v + q, k, v = q_fp8._data, k_fp8._data, v_fp8._data + else: + q_f16, k_f16, v_f16 = q, k, v + q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + k, v = [ + cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + for x in [k_f16, v_f16] + ] + fp8_meta_kwargs = {} + fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv[META_QKV] + fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv[META_S] + fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale[META_S] + fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale[META_O_CP] + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + else: + assert False, "FP8 is only supported with Fused Attention!" + else: + q_f16 = q + if use_fused_attention: + fp8_meta_kwargs = {} + fused_attn_qkv_dtype = TE_DType[q.dtype] + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + p2p_comm_buffers = [None for _ in range(cp_size)] if use_fused_attention and qkv_format in ["bshd", "sbhd"]: p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) @@ -1433,7 +1463,23 @@ def forward( batch_p2p_comm, ) - kv_inputs[i % 2] = p2p_comm_buffers[i] + if ( + not fp8 + or fp8_meta["recipe"].fp8_mha + or int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ): + kv_inputs[i % 2] = p2p_comm_buffers[i] + else: + # KV exchange is in BF16/FP16, cast received KV in each step + kv_inputs[i % 2] = cast_to_fp8( + p2p_comm_buffers[i], + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + ) + if fp8 and use_fused_attention: + fp8_meta_kwargs["amax_s"] = amax_per_step[0][i] + fp8_meta_kwargs["amax_o"] = amax_per_step[1][i] if causal: if i == 0: if pad_between_seqs_q: @@ -1474,38 +1520,40 @@ def forward( ), dim=-1, ).contiguous() - out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( - fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - TE_DType[q.dtype], - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=attn_mask_type, - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - ) + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q_inputs[i % 2], + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + fused_attn_qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + **fp8_meta_kwargs, ) - if len(rest) > 0: - attn_biases[i] = rest[0] + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) @@ -1572,42 +1620,44 @@ def forward( if attn_bias is not None: idx = (rank - i) % cp_size attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous() - out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( - fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv // 2, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - TE_DType[q.dtype], - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=( - None - if cu_seqlens_kv_padded is None - else cu_seqlens_kv_padded // 2 - ), - ) + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv // 2, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q_inputs[i % 2], + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + fused_attn_qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type="padding" if padding else "no_mask", + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=( + None + if cu_seqlens_kv_padded is None + else cu_seqlens_kv_padded // 2 + ), + **fp8_meta_kwargs, ) - if len(rest) > 0: - attn_biases[i] = rest[0] + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) @@ -1693,42 +1743,44 @@ def forward( ), dim=-1, ).contiguous() - out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( - fused_attn_fwd( - is_training, - max_seqlen_q // 2, - max_seqlen_kv, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - TE_DType[q.dtype], - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=( - None - if cu_seqlens_q_padded is None - else cu_seqlens_q_padded // 2 - ), - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - ) + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q // 2, + max_seqlen_kv, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q_inputs[i % 2], + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + fused_attn_qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type="padding" if padding else "no_mask", + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=( + None + if cu_seqlens_q_padded is None + else cu_seqlens_q_padded // 2 + ), + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + **fp8_meta_kwargs, ) - if len(rest) > 0: - attn_biases[i] = rest[0] + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None else: if qkv_format == "thd": # [t, np, hn] -> [t/2, np, hn] @@ -1795,38 +1847,40 @@ def forward( ), dim=-1, ).contiguous() - out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( - fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q, - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - TE_DType[q.dtype], - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=attn_mask_type, - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - ) + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q, + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + fused_attn_qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + **fp8_meta_kwargs, ) - if len(rest) > 0: - attn_biases[i] = rest[0] + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None else: # [b, sq, np, hn] -> [b*sq, np, hn] q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) @@ -1866,8 +1920,16 @@ def forward( softmax_lse_per_step[i - 1].squeeze_(-1) with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): + if fp8: + out_per_step[i - 1] = cast_from_fp8( + out_per_step[i - 1], + fp8_meta["scaling_fwd"], + META_O_CP, + fp8_dtype_forward, + TE_DType[torch.float32], + ) if i == 1: - out = torch.zeros_like(q) + out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) if causal and qkv_format != "thd": # [b, np, sq] -> [b, np, 2, sq//2] @@ -1951,13 +2013,55 @@ def forward( else: out = out.view(-1, *out.shape[-2:]) + if fp8 and use_fused_attention: + amax_cp_fwd = amax_per_step.amax(dim=1) + fp8_meta["scaling_fwd"].amax_history[0][META_S] = amax_cp_fwd[0] + fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1] + + out_f16 = out.to(q_fp8.dtype if fp8 and fp8_meta["recipe"].fp8_mha else q_f16.dtype) + if fp8 and (fp8_meta["recipe"].fp8_mha or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))): + out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward) + + if fp8 and fp8_meta["recipe"].fp8_mha: + out_ret = Float8Tensor( + data=out_fp8, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_O, + fp8_dtype=fp8_dtype_forward, + dtype=q_fp8.dtype, + ) + else: + out_ret = out_f16 + + if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q_save, kv_save, out_save = q, kv, out_fp8 + fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone() + fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone() + elif fp8 and fp8_meta["recipe"].fp8_mha: + kv_fp8 = Float8Tensor( + data=kv, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_QKV, + fp8_dtype=fp8_dtype_forward, + dtype=k_fp8.dtype, + ) + q_save, kv_save, out_save = q_fp8, kv_fp8, out_f16 + fp8_fwd_scales, fp8_fwd_scale_invs = None, None + else: + q_save, kv_save, out_save = q_f16, kv, out_f16 + fp8_fwd_scales, fp8_fwd_scale_invs = None, None + ctx.save_for_backward( - q, - kv, - out, + q_save, + kv_save, + out_save, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, + fp8_fwd_scales, + fp8_fwd_scale_invs, *cu_seqlens_q_per_step, *cu_seqlens_kv_per_step, *rng_states, @@ -1976,7 +2080,9 @@ def forward( ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention - return out + ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.fp8_meta = fp8_meta + return out_ret @staticmethod def backward(ctx, dout): @@ -1987,10 +2093,11 @@ def backward(ctx, dout): batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6] - cu_seqlens_q_per_step = ctx.saved_tensors[6 : 6 + cp_size] - cu_seqlens_kv_per_step = ctx.saved_tensors[6 + cp_size : 6 + cp_size * 2] - rng_states = ctx.saved_tensors[6 + cp_size * 2 : 6 + cp_size * 3] - attn_biases = ctx.saved_tensors[6 + cp_size * 3 : 6 + cp_size * 4] + (fp8_fwd_scales, fp8_fwd_scale_invs) = ctx.saved_tensors[6:8] + cu_seqlens_q_per_step = ctx.saved_tensors[8 : 8 + cp_size] + cu_seqlens_kv_per_step = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2] + rng_states = ctx.saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3] + attn_biases = ctx.saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4] causal = "causal" in ctx.attn_mask_type padding = "padding" in ctx.attn_mask_type @@ -2025,22 +2132,60 @@ def backward(ctx, dout): if ctx.use_fused_attention: # [b, np, sq//2] -> [b, np, sq//2, 1] softmax_lse_.unsqueeze_(-1) - if ctx.use_fused_attention: # [b, np, sq] -> [b, np, sq, 1] softmax_lse.unsqueeze_(-1) + + if ctx.fp8: + if ctx.use_fused_attention: + fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) + fused_attn_qkv_dtype = fp8_dtype_backward + fused_attn_dqkv_dtype = fp8_dtype_backward + fused_attn_backend = FusedAttnBackend["FP8"] + dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device) + dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device) + dkv_fp8_ = torch.empty_like(dkv_fp8) + dout_dtype = dout.dtype + if ctx.fp8_meta["recipe"].fp8_mha: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" + ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv + dout = dout._data + else: + dout = cast_to_fp8( + dout, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward + ) + p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]] + fp8_meta_kwargs = {} + fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV] + fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S] + fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O] + fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] + fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP] + fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S] + fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP] + fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV_CP] + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + else: + assert False, "FP8 is only supported with Fused Attention!" + else: + if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha: + q, kv, dout = [x.from_float8(x.dtype) for x in [q, kv, dout]] + dq = torch.empty_like(q) + if ctx.qkv_format == "thd" and causal: + dq[cu_seqlens_q_padded[-1] :].fill_(0) + p2p_comm_buffers = [ + torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), + torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), + ] + p2p_comm_buffers[0][0].copy_(kv) + if ctx.use_fused_attention: + fp8_meta_kwargs = {} + fused_attn_qkv_dtype = TE_DType[q.dtype] + fused_attn_dqkv_dtype = TE_DType[q.dtype] + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + out = out.view(*q.shape) dout = dout.view(*q.shape) - # Flash Attn outputs - dq = torch.empty_like(q) - if ctx.qkv_format == "thd" and causal: - dq[cu_seqlens_q_padded[-1] :].fill_(0) - - p2p_comm_buffers = [ - torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), - torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), - ] - p2p_comm_buffers[0][0].copy_(kv) send_recv_reqs = [] fa_optional_backward_kwargs = {} @@ -2056,18 +2201,40 @@ def backward(ctx, dout): send_tensor = p2p_comm_buffers[i % 2] recv_tensor = p2p_comm_buffers[(i + 1) % 2] - if i == 0: - send_tensor = send_tensor[0] - recv_tensor = recv_tensor[0] - if i == (cp_size - 1): - send_tensor = send_tensor[1] - recv_tensor = recv_tensor[1] - - send_recv_reqs = flash_attn_p2p_communicate( - rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm - ) + if ctx.fp8: + if i < cp_size - 1: + send_recv_reqs = flash_attn_p2p_communicate( + rank, + send_tensor[0], + send_dst, + recv_tensor[0], + recv_src, + ctx.cp_group, + batch_p2p_comm, + ) + else: + dkv_a2a_req = torch.distributed.all_to_all_single( + dkv_fp8, + dkv_fp8_, + group=ctx.cp_group, + async_op=True, + ) + send_recv_reqs = [dkv_a2a_req] + else: + if i == 0: + send_tensor = send_tensor[0] + recv_tensor = recv_tensor[0] + if i == (cp_size - 1): + send_tensor = send_tensor[1] + recv_tensor = recv_tensor[1] + send_recv_reqs = flash_attn_p2p_communicate( + rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm + ) kv = p2p_comm_buffers[i % 2][0] + if ctx.fp8 and ctx.use_fused_attention: + fp8_meta_kwargs["amax_dp"] = amax_per_step[0][i] + fp8_meta_kwargs["amax_dqkv"] = amax_per_step[0][i] # In reversed order of fwd if causal: if i == (cp_size - 1): @@ -2090,7 +2257,14 @@ def backward(ctx, dout): dout_ = dout.view(-1, *dout.shape[-3:]) elif ctx.qkv_format == "thd": q_, kv_, out_, dout_ = q, kv, out, dout - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] + if ctx.fp8: + aux_ctx_tensors = [ + softmax_lse, + softmax_lse, + rng_states[cp_size - i - 1], + ] + else: + aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( @@ -2103,10 +2277,10 @@ def backward(ctx, dout): kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, dout_, - TE_DType[q.dtype], - TE_DType[kv.dtype], + fused_attn_qkv_dtype, + fused_attn_dqkv_dtype, aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, attn_scale=ctx.softmax_scale, @@ -2114,6 +2288,8 @@ def backward(ctx, dout): qkv_layout=qkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, ) else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] @@ -2169,7 +2345,14 @@ def backward(ctx, dout): q_, out_, dout_ = q, out, dout # [2, t, np, hn] -> [2, t/2, np, hn] kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] + if ctx.fp8: + aux_ctx_tensors = [ + softmax_lse, + softmax_lse, + rng_states[cp_size - i - 1], + ] + else: + aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( @@ -2182,10 +2365,10 @@ def backward(ctx, dout): kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, dout_, - TE_DType[q.dtype], - TE_DType[kv.dtype], + fused_attn_qkv_dtype, + fused_attn_dqkv_dtype, aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=( None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2 @@ -2195,6 +2378,8 @@ def backward(ctx, dout): qkv_layout=qkv_layout, attn_mask_type="padding" if padding else "no_mask", attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, ) else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] @@ -2256,7 +2441,14 @@ def backward(ctx, dout): out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1) dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1) kv_ = kv - aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]] + if ctx.fp8: + aux_ctx_tensors = [ + softmax_lse_, + softmax_lse_, + rng_states[cp_size - i - 1], + ] + else: + aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( @@ -2269,10 +2461,10 @@ def backward(ctx, dout): kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, dout_, - TE_DType[q.dtype], - TE_DType[kv.dtype], + fused_attn_qkv_dtype, + fused_attn_dqkv_dtype, aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=( None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2 ), @@ -2282,6 +2474,8 @@ def backward(ctx, dout): qkv_layout=qkv_layout, attn_mask_type="padding" if padding else "no_mask", attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, ) else: if ctx.qkv_format == "thd": @@ -2325,7 +2519,10 @@ def backward(ctx, dout): ) else: if ctx.use_fused_attention: - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] + if ctx.fp8: + aux_ctx_tensors = [softmax_lse, softmax_lse, rng_states[cp_size - i - 1]] + else: + aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( @@ -2338,10 +2535,10 @@ def backward(ctx, dout): kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1], out, dout, - TE_DType[q.dtype], - TE_DType[kv.dtype], + fused_attn_qkv_dtype, + fused_attn_dqkv_dtype, aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, attn_scale=ctx.softmax_scale, @@ -2349,6 +2546,8 @@ def backward(ctx, dout): qkv_layout=qkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, ) else: # [b, sq, np, hn] -> [b*sq, np, hn] @@ -2383,6 +2582,8 @@ def backward(ctx, dout): **fa_optional_backward_kwargs, ) + if ctx.fp8: + dq = dq_fp8[(rank + i + 1) % cp_size] if i >= (cp_size - rank - 1) or not causal: # [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal # [b*sq, np, hn] -> [b, sq, np, hn] if not causal @@ -2395,7 +2596,17 @@ def backward(ctx, dout): # [b*sq//2, np, hn] -> [sq//2, b, np, hn] dq_ = dq_.view(-1, *dq.shape[-3:]) - if causal: + if ctx.fp8: + if i >= (cp_size - rank - 1) or not causal: + dq.copy_(dq_) + else: + if ctx.qkv_format == "bshd": + dq[:, 0, ...].fill_(0) + dq[:, 1, ...].copy_(dq_) + elif ctx.qkv_format == "sbhd": + dq[0].fill_(0) + dq[1].copy_(dq_) + elif causal: if i > (cp_size - rank - 1): dq.add_(dq_) elif i == (cp_size - rank - 1): @@ -2450,7 +2661,13 @@ def backward(ctx, dout): for req in send_recv_reqs: req.wait() - dkv = p2p_comm_buffers[(i + 1) % 2][1] + if ctx.fp8: + if i < cp_size - 1: + dkv = dkv_fp8_[(rank + i + 1) % cp_size] + else: + dkv = dkv_fp8[(rank + i + 1) % cp_size] + else: + dkv = p2p_comm_buffers[(i + 1) % 2][1] if ctx.use_fused_attention: dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0) if ctx.qkv_format in ["bshd", "sbhd"]: @@ -2469,7 +2686,17 @@ def backward(ctx, dout): # [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal dkv_ = dkv_.view(*dkv.shape) - if causal: + if ctx.fp8: + if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].copy_(dkv_) + dkv[:, :, 1, ...].fill_(0) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].copy_(dkv_) + dkv[:, 1, ...].fill_(0) + else: + dkv.copy_(dkv_) + elif causal: if i == (cp_size - 1): if rank == 0: if ctx.qkv_format == "bshd": @@ -2507,6 +2734,26 @@ def backward(ctx, dout): else: dkv.add_(dkv_) + if ctx.fp8 and ctx.use_fused_attention: + amax_cp_bwd = amax_per_step.amax(dim=1) + ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP] = amax_cp_bwd[0] + ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV_CP] = amax_cp_bwd[1] + if ctx.qkv_format in ["bshd", "sbhd"]: + # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or + # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] + dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) + dq, dkv = [ + cast_from_fp8( + x, + ctx.fp8_meta["scaling_bwd"], + META_DQKV_CP, + fp8_dtype_backward, + TE_DType[torch.float32], + ) + for x in [dq_fp8, dkv_fp8] + ] + dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]] + if causal: if ctx.qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] @@ -2527,6 +2774,25 @@ def backward(ctx, dout): dkv_[:, cu_seqlens_kv_padded[-1] :].fill_(0) dkv = dkv_ + if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha: + dq, dkv = [ + cast_to_fp8(x, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward) + for x in [dq, dkv] + ] + dq, dk, dv = [ + Float8Tensor( + data=x, + fp8_meta=ctx.fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=META_DQKV, + fp8_dtype=fp8_dtype_backward, + dtype=dout_dtype, + ) + for x in [dq, dkv[0], dkv[1]] + ] + else: + dk, dv = dkv[0], dkv[1] + if attn_dbias is not None: # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk] attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1) @@ -2534,8 +2800,8 @@ def backward(ctx, dout): return ( None, dq, - dkv[0], - dkv[1], + dk, + dv, None, None, None, @@ -2553,12 +2819,14 @@ def backward(ctx, dout): attn_dbias, None, None, + None, + None, ) -@jit_fuser +@torch.compile def get_seq_chunk_ids_to_all_gathered_kv( - local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size_left + local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size_left, device ): """Compute sequence chunk ids to the all-gathered KV.""" seq_end_idx = (local_chunk_id + 1) * max_seqlen_kv @@ -2569,7 +2837,7 @@ def get_seq_chunk_ids_to_all_gathered_kv( local_chunk_id - num_chunks + 1, local_chunk_id + 1, dtype=torch.int32, - device="cuda", + device=device, ) chunk_ids_to_all_gathered_kv = torch.where( chunk_ids < cp_size, 2 * chunk_ids, 2 * (2 * cp_size - chunk_ids) - 1 @@ -2683,6 +2951,7 @@ def forward( if (window_size is None or window_size[0] == -1) else window_size[0] ), + k.device, ) chunk_ids_to_kv_ag_per_step[i] = chunk_ids_to_kv_ag num_kv_chunks = chunk_ids_to_kv_ag.numel() @@ -3029,6 +3298,8 @@ def attn_forward_func_with_cp( deterministic=False, use_fused_attention=False, window_size=None, + fp8=False, + fp8_meta=None, ) -> torch.Tensor: """ Attention implementation with context parallelism. @@ -3109,6 +3380,8 @@ def attn_forward_func_with_cp( attn_bias, deterministic, use_fused_attention, + fp8, + fp8_meta, ) else: raise ValueError(f"Unsupported communication type: {cp_comm_type}!") @@ -5638,9 +5911,21 @@ def forward( and (fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen) ) + if fp8: + assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, ( + f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}" + " is required for FP8 attention!" + ) + assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!" + assert not context_parallel or fp8_meta["recipe"].reduce_amax, ( + "Amax reduction across TP+CP group is necessary when using context parallelism with" + " FP8!" + ) + if context_parallel: assert ( - fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + fp8 + or fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen ), f"{fused_attention_backend} does not work with context parallelism!" assert core_attention_bias_type not in [ "alibi" @@ -5670,19 +5955,14 @@ def forward( attn_mask_type=attn_mask_type, attn_bias_type=core_attention_bias_type, attn_bias=core_attention_bias, + deterministic=self.deterministic, use_fused_attention=True, window_size=window_size, + fp8=fp8, + fp8_meta=fp8_meta, ) else: with self.attention_dropout_ctx(): - if fp8: - assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, ( - f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}" - " is required for FP8 attention!" - ) - assert ( - fp8_meta is not None - ), "FP8 metadata fp8_meta is required for FP8 attention!" output = FusedAttnFunc.apply( self.training, max_seqlen_q, From 467b39a3aed56987c4b0ec60eef1935b994fd7da Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 20 Aug 2024 20:56:19 -0700 Subject: [PATCH 08/50] [PyTorch] Add support for padding mask in `UnfusedDotProductAttention` (#1073) * add support for padding in UnfusedDPA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add support for padding_causal/_bottom_right Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix padding_causal/_bottom_right Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * need to test max512 backend Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix mask logic in unfused Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use actual_seqlen for alibi/causal_bottom_right padding Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fixes and convert causal to causal_bottom_right for inference Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use causal in kv cache inference test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * simplify get_alibi logic Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * simplify the non-padding path for get_alibi Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * avoid batch_size loop in generating padding_causal/_bottom_right masks Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_numerics.py | 6 +- .../common/fused_attn/fused_attn.cpp | 5 +- transformer_engine/pytorch/attention.py | 172 +++++++++++++----- transformer_engine/pytorch/softmax.py | 39 ++-- transformer_engine/pytorch/transformer.py | 2 +- 5 files changed, 155 insertions(+), 69 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a2023f539a..85cd4fc256 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1655,8 +1655,8 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, ffn_hidden_size=4 * D, num_attention_heads=H, attn_input_format=input_format, - self_attn_mask_type="causal_bottom_right", - enc_dec_attn_mask_type="causal_bottom_right", + self_attn_mask_type="causal", + enc_dec_attn_mask_type="causal", layer_number=layer_number, attention_dropout=0.0, params_dtype=dtype, @@ -1670,7 +1670,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, qkv_format=input_format, layer_number=layer_number, attention_dropout=0.0, - attn_mask_type="causal_bottom_right", + attn_mask_type="causal", params_dtype=dtype, ) .cuda() diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 0fe62f8cb4..70f1fa409f 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -142,7 +142,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || (bias_type == NVTE_Bias_Type::NVTE_ALIBI && attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && sm_arch_ >= 90) || + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && + sm_arch_ >= 90) || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) || ((cudnn_runtime_version >= 90000) && (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) && diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 8fac4778c8..6a46d6c3c1 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -472,19 +472,25 @@ def get_attention_backend( use_fused_attention = False # Filter: Attention mask - # attn_mask_type | supported backends - # ------------------------------------------------------------------- - # no_mask | All - # padding | FlashAttention, FusedAttention - # causal | - # self-attention | All - # cross-attention | FusedAttention - # padding_causal | - # self-attention | FlashAttention, FusedAttention - # cross-attention | FusedAttention - # causal_bottom_right | All - # padding_causal_bottom_right | FlashAttention, FusedAttention - # arbitrary | UnfusedDotProductAttention + # attn_mask_type | attention_mask | supported backends + # ---------------------------------------------------------------------------------------- + # no_mask | None | All + # padding | | All + # self-attention | One tensor in shape [b, 1, 1, sq] | + # cross-attention | Tuple of two tensors in shapes | + # | [b, 1, 1, sq] and [b, 1, 1, skv] | + # causal | None | + # self-attention | | All + # cross-attention | | FusedAttention, UnfusedDotProductAttention + # padding_causal | Same as "padding" | + # self-attention | | All + # cross-attention | | FusedAttention, UnfusedDotProductAttention + # causal_bottom_right | None | All + # padding_causal_bottom_right | Same as "padding" | + # self-attention | | All + # cross-attention | | FlashAttention, UnfusedDotProductAttention + # arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention + # | [b, h, sq, skv] | if attn_mask_type == "arbitrary": if use_flash_attention: logger.debug("Disabling FlashAttention for arbitrary mask") @@ -492,9 +498,6 @@ def get_attention_backend( if use_fused_attention: logger.debug("Disabling FusedAttention for arbitrary mask") use_fused_attention = False - if use_unfused_attention and "padding" in attn_mask_type: - logger.debug("Disabling UnfusedDotProductAttention for %s mask", attn_mask_type) - use_unfused_attention = False if ( use_flash_attention and _flash_attn_2_1_plus @@ -780,7 +783,7 @@ def get_attention_backend( class InferenceParams: # pylint: disable=too-few-public-methods """ Inference parameters that are passed to the main model in order - to efficienly calculate and store the context during inference. + to efficiently calculate and store the context during inference. Parameters ---------- @@ -886,6 +889,8 @@ def get_alibi( num_heads: int, max_seqlen_q: int, max_seqlen_kv: int, + actual_seqlens_q: Optional[torch.Tensor] = None, + actual_seqlens_kv: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, bias_dtype: Optional[torch.dtype] = None, bottom_right_alignment: bool = True, @@ -899,6 +904,10 @@ def get_alibi( Maximum sequence length for queries. max_seqlen_kv: int Maximum sequence length for keys and values. + actual_seqlens_q: Optional[torch.Tensor], default = `None` + Actual sequence lengths for queries, in shape [batch_size]. + actual_seqlens_kv: Optional[torch.Tensor], default = `None` + Actual sequence lengths for keys and values, in shape [batch_size]. alibi_slopes: Optional[torch.Tensor], default = `None` Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads]. bias_dtype: Optional[torch.dtype], default = `None` @@ -912,10 +921,12 @@ def get_alibi( alibi_slopes: torch.Tensor ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads]. alibi_bias: torch.Tensor - ALiBi bias in FP32 or `bias_dtype`. If `alibi_slopes` is in [num_heads] shape, - then `alibi_bias` is in [1, num_heads, max_seqlen_q, max_seqlen_kv], and if - `alibi_slopes` is in [batch_size, num_heads], then the bias is in - [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. + ALiBi bias in FP32 or `bias_dtype`. Its shape is + (1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape, + and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or + (2) [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in + [batch_size, num_heads] shape, or, if `alibi_slopes` is in [num_heads] shape and + `actual_seqlens_q` and `actual_seqlens_kv` are not `None`. """ global _alibi_cache if _alibi_cache["_alibi_slopes_require_update"]: @@ -941,17 +952,23 @@ def get_alibi( slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1]) if _alibi_cache["_alibi_slopes"].dim() == 2: slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1]) - if bottom_right_alignment: - bias = torch.arange(1 - max_seqlen_kv, 1, dtype=torch.int32, device="cuda").view( - 1, 1, 1, max_seqlen_kv - ) - else: - bias = torch.arange( - 1 - max_seqlen_q, max_seqlen_kv - max_seqlen_q + 1, dtype=torch.int32, device="cuda" - ).view(1, 1, 1, max_seqlen_kv) - bias = bias - torch.arange(1 - max_seqlen_q, 1, dtype=torch.int32, device="cuda").view( + bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( 1, 1, max_seqlen_q, 1 + ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( + 1, 1, 1, max_seqlen_kv ) + if actual_seqlens_q is None and actual_seqlens_kv is None: + if bottom_right_alignment: + bias = bias + max_seqlen_kv - max_seqlen_q + elif actual_seqlens_q is not None and actual_seqlens_kv is not None: + batch_size = actual_seqlens_q.shape[0] + bias = bias.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + if bottom_right_alignment: + bias = bias + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1) + else: + assert ( + False + ), "actual_seqlens_q and actual_seqlens_kv need to be both None or torch.Tensors!" bias = bias.abs().mul(-1) bias = bias * _alibi_cache["_alibi_slopes"].view(slopes_shape) _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv @@ -3705,6 +3722,7 @@ class UnfusedDotProductAttention(torch.nn.Module): def __init__( self, softmax_scale: float, + attention_type: str = "self", attention_dropout: float = 0.0, attention_dropout_ctx: Optional[Callable] = nullcontext, layer_number: Optional[int] = None, @@ -3712,6 +3730,7 @@ def __init__( super().__init__() self.softmax_scale = softmax_scale + self.attention_type = attention_type self.attention_dropout_ctx = attention_dropout_ctx self.layer_number = layer_number @@ -3751,6 +3770,58 @@ def forward( query_layer, key_layer, value_layer = [ x.transpose(0, 1) for x in [query_layer, key_layer, value_layer] ] + batch_size, max_seqlen_q, max_seqlen_kv = ( + query_layer.shape[1], + query_layer.shape[0], + key_layer.shape[0], + ) + if "padding" in attn_mask_type: + if self.attention_type == "self": + assert attention_mask.shape == ( + batch_size, + 1, + 1, + max_seqlen_q, + ), "attention_mask should be a single tensor with [b, 1, 1, sq] shape!" + attention_mask = torch.logical_or( + attention_mask.squeeze(1).unsqueeze(3), attention_mask + ) + else: + assert ( + len(attention_mask) == 2 + and attention_mask[0].shape == (batch_size, 1, 1, max_seqlen_q) + and attention_mask[1].shape == (batch_size, 1, 1, max_seqlen_kv) + ), ( + "attention_mask should be a tuple of two tensors with shapes " + "[b, 1, 1, sq] and [b, 1, 1, skv]!" + ) + attention_mask = torch.logical_or( + attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1] + ) + mask = attention_mask.squeeze(1).logical_not() + actual_seqlens_q = mask[:, :, 0].sum(dim=1) + actual_seqlens_kv = mask[:, 0, :].sum(dim=1) + mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( + 1, 1, max_seqlen_q, 1 + ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( + 1, 1, 1, max_seqlen_kv + ) + if attn_mask_type == "padding_causal": + attention_mask = torch.logical_or( + torch.where(mask.view(1, 1, max_seqlen_q, max_seqlen_kv) < 0, 1, 0), + attention_mask, + ) + if attn_mask_type == "padding_causal_bottom_right": + attention_mask = torch.logical_or( + torch.where( + mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1) + < 0, + 1, + 0, + ), + attention_mask, + ) batch_size, seqlen = query_layer.shape[1], query_layer.shape[0] apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 @@ -3805,7 +3876,7 @@ def forward( key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] beta=0.0, alpha=scale, - ) + ).view(*output_size) elif core_attention_bias_type == "pre_scale_bias": assert core_attention_bias is not None, "core_attention_bias should not be None!" @@ -3813,10 +3884,7 @@ def forward( query_layer.transpose(0, 1), # [b * np, sq, hn] key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] ) - matmul_result = ( - matmul_result.view(output_size[0], output_size[1], output_size[2], output_size[3]) - + core_attention_bias - ).view(-1, output_size[2], output_size[3]) + matmul_result = matmul_result.view(*output_size) + core_attention_bias matmul_result *= scale elif core_attention_bias_type in ["post_scale_bias", "alibi"]: @@ -3827,6 +3895,8 @@ def forward( output_size[1], output_size[2], output_size[3], + actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None, + actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None, alibi_slopes=alibi_slopes, bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], ) @@ -3837,26 +3907,21 @@ def forward( beta=0.0, alpha=scale, ) - matmul_result = ( - ( - matmul_result.view( - output_size[0], output_size[1], output_size[2], output_size[3] - ) - + core_attention_bias - ) - .view(-1, output_size[2], output_size[3]) - .to(dtype=query_layer.dtype) + matmul_result = (matmul_result.view(*output_size) + core_attention_bias).to( + dtype=query_layer.dtype ) - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - # attention scores and attention mask [b, np, sq, sk] softmax_scale = self.layer_number if apply_qk_layer_scaling else None attention_probs = self.scale_mask_softmax( - attention_scores, attention_mask, attn_mask_type, softmax_scale + matmul_result, attention_mask, attn_mask_type, softmax_scale ) + # mask out the pad positions in softmax results, mostly for the rows (pad tokens from q) + # the columns (pad tokens from k) are already zeroed out during softmax + if "padding" in attn_mask_type: + attention_probs = attention_probs.masked_fill(attention_mask, 0) + # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. with self.attention_dropout_ctx(): @@ -6232,7 +6297,10 @@ def __init__( ) self.unfused_attention = UnfusedDotProductAttention( - softmax_scale, **attn_kwargs, layer_number=layer_number + softmax_scale, + attention_type=attention_type, + **attn_kwargs, + layer_number=layer_number, ) def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument @@ -6522,6 +6590,11 @@ def forward( if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" + # convert causal to causal_bottom_right in inference when KV-caching is in use + # so users can run with the same attn_mask_type for training and inference + if attn_mask_type in ["causal", "padding_causal"]: + attn_mask_type = attn_mask_type + "_bottom_right" + if qkv_format == "bshd": key_layer = key_layer.transpose(0, 1) value_layer = value_layer.transpose(0, 1) @@ -6628,7 +6701,6 @@ def forward( attention_mask is not None ), "Please provide attention_mask for padding!" if self.attention_type == "self": - assert max_seqlen_q == max_seqlen_kv cu_seqlens_q = get_cu_seqlens(attention_mask) cu_seqlens_kv = cu_seqlens_q else: diff --git a/transformer_engine/pytorch/softmax.py b/transformer_engine/pytorch/softmax.py index 3632d2f367..4fb8a28857 100644 --- a/transformer_engine/pytorch/softmax.py +++ b/transformer_engine/pytorch/softmax.py @@ -329,25 +329,22 @@ def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: return False # sk must be 16 ~ 16384 if sk % 8 != 0: return False # sk must be divisor of 8 - if self.attn_mask_type == "arbitrary": - return False # Custom masks not supported - + if sq == 1: + return False # sq must be > 1 if self.attn_mask_type == "causal" and sq != sk: return False # Fused causal kernel only support causal_bottom_right if ( sq % 4 == 0 # sq must be divisor of 4 and attn_batches % 4 == 0 # np * b must be divisor of 4 - and self.attn_mask_type != "arbitrary" # Custom masks not supported ): batch_per_block = self.get_batch_per_block(int(sk)) - - if self.attn_mask_type == "padding": + if "padding" in self.attn_mask_type or self.attn_mask_type == "arbitrary": if ( mask is not None and sq % batch_per_block == 0 - and mask.shape[-2] == sq - and mask.shape[-1] == sk + and mask.shape[0] in [1, b] + and mask.shape[1:] == (1, sq, sk) ): return True else: @@ -358,13 +355,21 @@ def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: def forward_fused_softmax( self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None ) -> torch.Tensor: - """Fused masked softmax kernel""" + """ + Fused masked softmax path. + attn_mask_type | module + ----------------------------------------------------------------------------------------- + no_mask | ScaledSoftmax + causal (self-attention), causal_bottom_right | ScaledAlignedCausalMaskedSoftmax + padding, padding_causal, padding_causal_bottom_right | ScaledMaskedSoftmax + arbitrary ([1, 1, sq, sk] or [b, 1, sq, sk]) | ScaledMaskedSoftmax + """ scale = 1.0 if scale is None else scale - if "causal" in self.attn_mask_type: + if self.attn_mask_type in ["causal", "causal_bottom_right"]: return ScaledAlignedCausalMaskedSoftmax.apply(inp, scale) - # input is 4D tensor (b, np, sq, sk) + # input is 4D tensor (1, 1, sq, sk) or (b, 1, sq, sk) if mask is not None and self.attn_mask_type != "no_mask": return ScaledMaskedSoftmax.apply(inp, mask, scale) return ScaledSoftmax.apply(inp, scale) @@ -379,13 +384,19 @@ def forward_torch_softmax( if scale is not None: inp = inp * scale - if "causal" in self.attn_mask_type: + if self.attn_mask_type in ["causal", "causal_bottom_right"]: seq_len_q, seq_len_k = inp.size(2), inp.size(3) if is_in_onnx_export_mode() and self.kvcache_max_seq > 0: assert self.kvcache_max_seq >= seq_len_k - mask = _get_onnx_export_causal_mask(seq_len_q, seq_len_k, self.onnx_causal_mask) + causal_mask = _get_onnx_export_causal_mask( + seq_len_q, seq_len_k, self.onnx_causal_mask + ) + else: + causal_mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k) + if mask is None: + mask = causal_mask else: - mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k) + mask = torch.logical_or(mask, causal_mask) mask_output = inp if mask is not None and self.attn_mask_type != "no_mask": diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 4cbee3d628..bd6e27594d 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -624,7 +624,7 @@ def forward( Whether to set output tensors to 0 or not before use. inference_params: InferenceParams, default = None Inference parameters that are passed to the main model in order - to efficienly calculate and store the context during inference. + to efficiently calculate and store the context during inference. """ if self_attn_mask_type is None: From 47caafb28fd3f8c99ada92f9d7504715d268aa91 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 21 Aug 2024 22:33:22 -0700 Subject: [PATCH 09/50] Re-add framework specific required dependencies for source build (#1124) * Re-add framework specific required dependencies for source build Signed-off-by: Kirthi Shankar Sivamani * fix build Signed-off-by: Kirthi Shankar Sivamani * Fix Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- setup.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/setup.py b/setup.py index e418cb95ff..6cee4690dc 100644 --- a/setup.py +++ b/setup.py @@ -89,6 +89,18 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: if not found_pybind11(): setup_reqs.append("pybind11") + # Framework-specific requirements + if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): + if "pytorch" in frameworks: + install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.5.8,!=2.0.9,!=2.1.0"]) + test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"]) + if "jax" in frameworks: + install_reqs.extend(["jax", "flax>=0.7.1"]) + test_reqs.extend(["numpy", "praxis"]) + if "paddle" in frameworks: + install_reqs.append("paddlepaddle-gpu") + test_reqs.append("numpy") + return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]] From a33537448c58a5474a246f57ea3ba00932ec472c Mon Sep 17 00:00:00 2001 From: NVJiangShao <91270701+StudyingShao@users.noreply.github.com> Date: Fri, 23 Aug 2024 00:43:59 +0800 Subject: [PATCH 10/50] [MoE][Common/PyTorch] Add permutation (#936) * Add permutation functions * Add permutation ops * Remove the dependency on cutlass * Move permutation.py out of module dir * Rewrite the unit test and enable skipping if FP8 is unavailable * Rename exposed C++ API and reorder its parameters + take NVTETensor as inputs * Use Float8Tensor for FP8 input * Move dtype to ctx --------- Signed-off-by: Jiang Shao Co-authored-by: Qi Zhang Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> --- tests/pytorch/test_permutation.py | 515 ++++++++++++++++++ transformer_engine/common/CMakeLists.txt | 1 + .../common/gemm/cublaslt_gemm.cu | 2 +- .../include/transformer_engine/permutation.h | 21 + .../common/permutation/permutation.cu | 369 +++++++++++++ transformer_engine/pytorch/__init__.py | 1 + transformer_engine/pytorch/csrc/common.h | 1 + transformer_engine/pytorch/csrc/extensions.h | 20 + .../pytorch/csrc/extensions/permutation.cu | 170 ++++++ .../pytorch/csrc/extensions/pybind.cpp | 6 + transformer_engine/pytorch/permutation.py | 289 ++++++++++ 11 files changed, 1394 insertions(+), 1 deletion(-) create mode 100644 tests/pytorch/test_permutation.py create mode 100644 transformer_engine/common/include/transformer_engine/permutation.h create mode 100644 transformer_engine/common/permutation/permutation.cu create mode 100644 transformer_engine/pytorch/csrc/extensions/permutation.cu create mode 100644 transformer_engine/pytorch/permutation.py diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py new file mode 100644 index 0000000000..99bd706b45 --- /dev/null +++ b/tests/pytorch/test_permutation.py @@ -0,0 +1,515 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch +import pytest +from typing import Dict, List + +from transformer_engine.pytorch import moe_permute as te_permute, moe_unpermute as te_unpermute +from transformer_engine.pytorch.utils import is_bf16_compatible +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.float8_tensor import Float8Tensor +import transformer_engine_torch as tex + + +seed = 1234 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) + + +def pytorch_permute(tokens, indices, num_out_tokens: int = None): + """ + Permute the tokens based on the indices. Token with the same index will be grouped together. + The input indices shape is [tokens, top_k], it indicates which experts were selected by each token separately. + + Args: + tokens: torch.Tensor + The input token tensor. + indices: torch.Tensor + The token to expert indices tensor, should have a shape of [num_tokens] or [num_tokens, topk]. + num_out_tokens: int, optional + The effective output token count, when enabling the capacity factor, should equal the number of tokens not dropped. + By default, set to None, meaning no tokens are dropped. + + Returns: + torch.Tensor: + The permuted tensor. + torch.Tensor: + The sorted_indices corresponding permuted tensor. + """ + if indices.dim() == 1: + topk = 1 + else: + topk = indices.size(1) + flatten_indices = indices.view(-1) + sorted_indices = torch.argsort(flatten_indices, stable=True) + num_out_tokens = num_out_tokens if num_out_tokens is not None else flatten_indices.size(0) + + permuted_tokens = tokens.index_select(0, sorted_indices[:num_out_tokens] // topk) + return permuted_tokens, sorted_indices + + +def pytorch_unpermute( + permuted_tokens: torch.Tensor, + sorted_indices: torch.Tensor, + probs: torch.Tensor = None, +): + """ + Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the tokens with their + corresponding probabilities. + + Args: + permuted_tokens: torch.Tensor + The tensor of permuted tokens to be unpermuted. + sorted_indices: torch.Tensor + The tensor of sorted indices used to unpermute the tokens. + probs: torch.Tensor, optional + The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will + be merged with their respective probabilities. + + Returns: + torch.Tensor: + The unpermuted tokens, optionally merged with probabilities. + """ + + if probs is not None: + # Unpermute and merge the tokens with their probabilities + num_unpermuted_tokens = probs.numel() + topk = probs.size(1) + else: + # Unpermute the tokens without merge + num_unpermuted_tokens = sorted_indices.size(0) + topk = 1 + unpermuted_tokens = torch.zeros( + [num_unpermuted_tokens, permuted_tokens.shape[-1]], + dtype=permuted_tokens.dtype, + device=permuted_tokens.device, + ) + + unpermuted_tokens.index_copy_(0, sorted_indices[: permuted_tokens.size(0)], permuted_tokens) + unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1)) + if probs is not None: + unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1) + unpermuted_tokens = unpermuted_tokens.sum(dim=1) + return unpermuted_tokens + + +def dtype_tols(te_dtype: tex.DType) -> Dict[str, float]: + """Estimated tolerances for a datatype + + Based on tolerances for torch.testing.assert_close. + + """ + if te_dtype == tex.DType.kFloat32: + return dict(rtol=1.0e-6, atol=1.0e-6) + if te_dtype == tex.DType.kFloat16: + return dict(rtol=3.0e-3, atol=1.0e-5) + if te_dtype == tex.DType.kBFloat16: + return dict(rtol=2.0e-2, atol=1.0e-5) + if te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3: + return dict(rtol=2.0e-1, atol=1.0e-1) + raise ValueError(f"Unsuppored dtype ({te_dtype})") + + +def _test_permutation( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, + with_probs, + BENCHMARK=False, +): + if not with_probs and topK > 1: + pytest.skip("Only permutations with topK=1 and without probabilities are supported.") + + if topK > num_expert: + pytest.skip("topK should be smaller than the number of experts.") + + if num_out_tokens == None: + num_out_tokens = num_tokens * topK + + print( + f"token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" + ) + + fp8 = False + # Convert TE dtypes to PyTorch dtypes + if te_dtype == tex.DType.kFloat32: + dtype = torch.float32 + elif te_dtype == tex.DType.kFloat16: + dtype = torch.float16 + elif te_dtype == tex.DType.kBFloat16: + dtype = torch.bfloat16 + elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3): + dtype = torch.uint8 + fp8 = True + else: + pytest.skip("Invalid dtype.") + + if fp8: + permute_fwd_input = torch.rand( + size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + permute_bwd_input = torch.rand( + size=(num_out_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + unpermute_bwd_input = torch.rand( + size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + + permute_fwd_input = Float8Tensor.to_float8( + permute_fwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + ) + permute_bwd_input = Float8Tensor.to_float8( + permute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + ) + unpermute_bwd_input = Float8Tensor.to_float8( + unpermute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + ) + + pytorch_permute_fwd_input = permute_fwd_input.from_float8(torch.float16) + pytorch_permute_bwd_input = permute_bwd_input.from_float8(torch.float16) + pytorch_unpermute_bwd_input = unpermute_bwd_input.from_float8(torch.float16) + else: + pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() + pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + + pytorch_permute_fwd_input.requires_grad_(True) + + if num_tokens > 0: + indices = torch.stack([torch.randperm(num_expert)[:topK] for _ in range(num_tokens)]) + else: + indices = torch.empty((num_tokens, topK)) + indices = indices.to(torch.int32).cuda() + + probs = None + if with_probs: + probs = torch.rand(num_tokens, topK).cuda() + row_sums = probs.sum(dim=1, keepdim=True) + probs = probs / row_sums + probs.requires_grad_(True) + + ################################################################################################################################### + # + # PyTorch Permutation + # + ################################################################################################################################### + pytorch_permute_output, sorted_indices = pytorch_permute( + pytorch_permute_fwd_input, indices, num_out_tokens + ) + pytorch_permute_output.backward(pytorch_permute_bwd_input, retain_graph=True) + + pytorch_unpermute_fwd_input = pytorch_permute_output.detach() + pytorch_unpermute_fwd_input.requires_grad_(True) + + pytorch_unpermute_output = pytorch_unpermute( + pytorch_unpermute_fwd_input, sorted_indices, probs=probs + ) + pytorch_unpermute_output.backward(pytorch_unpermute_bwd_input, retain_graph=True) + + ################################################################################################################################### + # + # TE Permutation + # + ################################################################################################################################### + te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach() + te_permute_fwd_input.requires_grad_(True) + te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach() + + te_permute_output, row_id_map = te_permute( + te_permute_fwd_input, te_dtype, indices, num_out_tokens + ) + te_permute_output.backward(te_permute_bwd_input, retain_graph=True) + + te_probs = None + if with_probs: + te_probs = probs.detach() + te_probs.requires_grad_(True) + te_unpermute_fwd_input = te_permute_output.detach() + te_unpermute_fwd_input.requires_grad_(True) + te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() + + te_unpermute_output = te_unpermute(te_unpermute_fwd_input, te_dtype, row_id_map, te_probs) + te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True) + + ################################################################################################################################### + # + # Results Check + # + ################################################################################################################################### + tols = dtype_tols(te_dtype) + + if fp8: + te_permute_output_ = te_permute_output.from_float8(torch.float32) + te_permute_fwd_input_grad = te_permute_fwd_input.grad.from_float8(torch.float32) + te_unpermute_output_ = te_unpermute_output.from_float8(torch.float32) + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.from_float8(torch.float32) + else: + te_permute_output_ = te_permute_output.float() + te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() + te_unpermute_output_ = te_unpermute_output.float() + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float() + + torch.testing.assert_close( + pytorch_permute_output.float(), + te_permute_output_, + msg=f"Mismatch in te_permute fwd", + ) + torch.testing.assert_close( + pytorch_permute_fwd_input.grad.float(), + te_permute_fwd_input_grad, + msg=f"Mismatch in te_permute bwd", + **tols, + ) + torch.testing.assert_close( + pytorch_unpermute_output.float(), + te_unpermute_output_, + msg=f"Mismatch in te_unpermute fwd", + **tols, + ) + torch.testing.assert_close( + pytorch_unpermute_fwd_input.grad.float(), + te_unpermute_fwd_input_grad, + msg=f"Mismatch in te_unpermute bwd", + **tols, + ) + if with_probs: + torch.testing.assert_close( + probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols + ) + + if not pytorch_permute_fwd_input.numel(): + print("Empty pytorch_permute_fwd_input activation test passed.") + return + + ################################################################################################################################### + # + # Benchmark + # + ################################################################################################################################### + def backward_wrapper( + act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False + ): + # Set forward_input.grad to None to avoid grad accumulation. + if accumulate_grad == False: + for i in forward_input: + i.grad = None + return act.backward(backward_input, retain_graph=retain_graph) + + if BENCHMARK: + t1 = perf_test_cuda_kernel( + lambda: pytorch_permute(pytorch_permute_fwd_input, indices, num_out_tokens) + ) + t2 = perf_test_cuda_kernel( + lambda: te_permute(te_permute_fwd_input, te_dtype, indices, num_out_tokens) + ) + print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: backward_wrapper( + pytorch_permute_output, + pytorch_permute_bwd_input, + forward_input=[pytorch_permute_fwd_input], + retain_graph=True, + accumulate_grad=False, + ) + ) + t2 = perf_test_cuda_kernel( + lambda: backward_wrapper( + te_permute_output, + te_permute_bwd_input, + forward_input=[te_permute_fwd_input], + retain_graph=True, + accumulate_grad=False, + ) + ) + print(f"permute\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: pytorch_unpermute(pytorch_unpermute_fwd_input, sorted_indices, probs=probs) + ) + t2 = perf_test_cuda_kernel( + lambda: te_unpermute(te_unpermute_fwd_input, te_dtype, row_id_map, te_probs) + ) + print(f"unpermute\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: backward_wrapper( + pytorch_unpermute_output, + pytorch_unpermute_bwd_input, + forward_input=( + [pytorch_unpermute_fwd_input, probs] + if with_probs + else [pytorch_unpermute_fwd_input] + ), + retain_graph=True, + accumulate_grad=False, + ) + ) + t2 = perf_test_cuda_kernel( + lambda: backward_wrapper( + te_unpermute_output, + te_unpermute_bwd_input, + forward_input=( + [te_unpermute_fwd_input, te_probs] if with_probs else [te_unpermute_fwd_input] + ), + retain_graph=True, + accumulate_grad=False, + ) + ) + print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + +def perf_test_cuda_kernel(cuda_kernel_fn): + if torch.cuda.is_available(): + # create CUDA event + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # warmup + for _ in range(50): + cuda_kernel_fn() + + start_event.record() + for _ in range(100): + cuda_kernel_fn() + end_event.record() + torch.cuda.synchronize() + + elapsed_time_ms = start_event.elapsed_time(end_event) + return elapsed_time_ms / 100 + else: + pytest.skip("CUDA is not available.") + + +# TE tensor dtypes +_te_dtypes: List[tex.DType] = [tex.DType.kFloat32, tex.DType.kFloat16] +if is_bf16_compatible(): + _te_dtypes.append(tex.DType.kBFloat16) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_tokens", [4096]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +def test_permutation( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, +): + with_probs = True + BENCHMARK = False + + _test_permutation( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +# Only run FP8 tests on H100. +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) +@pytest.mark.parametrize("num_tokens", [2048]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +def test_permutation_fp8( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, +): + with_probs = True + BENCHMARK = False + + _test_permutation( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_tokens", [4096]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +def test_permutation_topk1_no_probs( + te_dtype, + num_tokens, + num_expert, + hidden_size, +): + topK = 1 + num_out_tokens = None + with_probs = False + BENCHMARK = False + + _test_permutation( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +def test_permutation_single_case(): + print("GPU:", torch.cuda.get_device_name(0)) + + # te_dtype = tex.DType.kFloat32 + # te_dtype = tex.DType.kFloat16 + # te_dtype = tex.DType.kBFloat16 + te_dtype = tex.DType.kFloat8E5M2 + # te_dtype = tex.DType.kFloat8E4M3 + + num_tokens = 10 + num_expert = 4 + hidden_size = 16 + topK = 2 + num_out_tokens = num_tokens * topK - 1 + with_probs = True + Benchmark = True + + _test_permutation( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=Benchmark, + ) + + +if __name__ == "__main__": + test_permutation_single_case() diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index a4497751f4..06bfec49b4 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -62,6 +62,7 @@ list(APPEND transformer_engine_SOURCES layer_norm/ln_api.cpp layer_norm/ln_bwd_semi_cuda_kernel.cu layer_norm/ln_fwd_cuda_kernel.cu + permutation/permutation.cu rmsnorm/rmsnorm_api.cpp rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu rmsnorm/rmsnorm_fwd_cuda_kernel.cu diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 8667b64e65..593ec086d7 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -255,7 +255,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, "Unable to find suitable cuBLAS GEMM algorithm"); NVTE_CHECK_CUBLAS(status); - if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms"); + if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms"); // D = alpha * (A * B) + beta * C NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, diff --git a/transformer_engine/common/include/transformer_engine/permutation.h b/transformer_engine/common/include/transformer_engine/permutation.h new file mode 100644 index 0000000000..c6263bf87e --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/permutation.h @@ -0,0 +1,21 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_PERMUTATION_H_ +#define TRANSFORMER_ENGINE_PERMUTATION_H_ + +#include "transformer_engine.h" + +void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor sorted_row_id, + NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad, + const NVTETensor input_fwd, const int num_rows, const int topK, + const int num_cols, const int num_out_tokens, cudaStream_t stream = nullptr); + +void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map, + const NVTETensor prob, const int num_rows, const int topK, const int num_cols, + cudaStream_t stream = nullptr); + +#endif // TRANSFORMER_ENGINE_PERMUTATION_H_ diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu new file mode 100644 index 0000000000..2b894fbfdc --- /dev/null +++ b/transformer_engine/common/permutation/permutation.cu @@ -0,0 +1,369 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include "../common.h" + +static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id_map, + const int num_rows, const int topK, + const int num_out_tokens) { + // Each block corresponds to one source token + // row_id_map[topK][num_rows] + const int bid = blockIdx.x; + const int tid = threadIdx.x; + const int idx = bid * blockDim.x + tid; + + if (idx >= num_rows * topK) return; + + int source_row = sorted_row_id[idx]; + int source_token_id = source_row / topK; + int source_topK_id = source_row % topK; + + if (idx >= num_out_tokens) { + // Set the indices of dropped tokens to -1 + row_id_map[source_topK_id * num_rows + source_token_id] = -1; + } else { + // Create a row id map for subsequent unpermute operation + row_id_map[source_topK_id * num_rows + source_token_id] = idx; + } +} + +template +__global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const int *row_id_map, + const float *prob, const int num_rows, const int topK, + const int num_cols) { + extern __shared__ int8_t s_mem[]; + TCompute *s_prob = reinterpret_cast(s_mem); + + // Each block corresponds to one dest token + const int source_token = blockIdx.x; + const int tid = threadIdx.x; + + if (hasProb) { + for (int i = tid; i < topK; i += blockDim.x * blockDim.y) { + // Load all the topK probs related to the source row into smem + s_prob[i] = TCompute(prob[source_token * topK + i]); + } + __syncthreads(); + } + + // Register buffers for vector type (float4) memory access + float4 frag_load_store; + T *frag_load_store_ptr = reinterpret_cast(&frag_load_store); + + // Number of elemments in frag_load_store + static constexpr int kElementsPerAccess = 16 / sizeof(T); + + // Traverse along the hidden dimention + for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) { + TCompute frag_elem[kElementsPerAccess]; + TCompute frag_sum[kElementsPerAccess]; + + int source_row = row_id_map[source_token]; + + // source_row == -1 represents a dropped token + if (source_row != -1) { + const T *source_row_ptr = input + source_row * num_cols; + + frag_load_store = __ldlu(reinterpret_cast(source_row_ptr + i)); + + for (int e = 0; e < kElementsPerAccess; e++) { + frag_sum[e] = TCompute(frag_load_store_ptr[e]); + } + + if (hasProb) { + for (int e = 0; e < kElementsPerAccess; e++) { + frag_sum[e] = frag_sum[e] * s_prob[0]; + } + } + } else { + for (int e = 0; e < kElementsPerAccess; e++) { + frag_sum[e] = TCompute(0.0f); + } + } + + for (int k = 1; k < topK; k++) { + source_row = row_id_map[k * num_rows + source_token]; + + if (source_row == -1) continue; + + const T *source_row_ptr = input + source_row * num_cols; + + frag_load_store = __ldlu(reinterpret_cast(source_row_ptr + i)); + + for (int e = 0; e < kElementsPerAccess; e++) { + frag_elem[e] = TCompute(frag_load_store_ptr[e]); + } + + if (hasProb) { + for (int e = 0; e < kElementsPerAccess; e++) { + frag_elem[e] = frag_elem[e] * s_prob[k]; + } + } + + for (int e = 0; e < kElementsPerAccess; e++) { + frag_sum[e] = frag_sum[e] + frag_elem[e]; + } + } + + T *dest_row_ptr = unpermuted_output + source_token * num_cols; + + for (int e = 0; e < kElementsPerAccess; e++) { + if constexpr ((std::is_same_v || std::is_same_v) && + (!hasProb)) { + frag_sum[e] = frag_sum[e] / TCompute(topK); + } + frag_load_store_ptr[e] = T(frag_sum[e]); + } + + *reinterpret_cast(dest_row_ptr + i) = frag_load_store; + } +} + +template +__global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *act_grad, + const float *prob, float *prob_grad, const int *row_id_map, + const int num_rows, const int topK, const int num_cols) { + extern __shared__ int8_t s_mem[]; + TCompute *s_prob = reinterpret_cast(s_mem); + + // Each block corresponds to one source token + const int source_token = blockIdx.x; + const int tid = threadIdx.x; + + if (hasProb) { + for (int i = tid; i < topK; i += blockDim.x) { + // Load all the topK probs related to the source row into smem + s_prob[i] = TCompute(prob[source_token * topK + i]); + } + __syncthreads(); + } + + // Accumulators for the calculation of prob_grad + float accum[topKTile] = {0.0f}; + + // Register buffers for vector type (float4) memory access + float4 frag_load_store; + T *frag_load_store_ptr = reinterpret_cast(&frag_load_store); + + // Number of elemments in frag_load_store + static constexpr int kElementsPerAccess = 16 / sizeof(T); + + // The starting address of each source row + const T *source_row_ptr = input_bwd + source_token * num_cols; + + // Traverse along the hidden dimention + for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) { + TCompute frag_src[kElementsPerAccess]; + + frag_load_store = __ldlu(reinterpret_cast(source_row_ptr + i)); + + for (int e = 0; e < kElementsPerAccess; e++) frag_src[e] = TCompute(frag_load_store_ptr[e]); + + int index = source_token; + + // Process each row in the corresponding topK rows + for (int k = 0; k < topKTile; k++) { + if (k == topK) break; + + int dest_row = row_id_map[index]; + index += num_rows; + + if (dest_row != -1) { + if (hasProb) { + // Calculate act_grad in unpermute bwd + for (int e = 0; e < kElementsPerAccess; e++) + frag_load_store_ptr[e] = T(frag_src[e] * s_prob[k]); + } else { + // permute fwd + for (int e = 0; e < kElementsPerAccess; e++) frag_load_store_ptr[e] = T(frag_src[e]); + } + + T *dest_row_ptr = act_grad + dest_row * num_cols; + *reinterpret_cast(dest_row_ptr + i) = frag_load_store; + + if (hasProb) { + // Inner product calculation for prob_grad in unpermute bwd + const T *input_fwd_ptr = input_fwd + dest_row * num_cols; + + frag_load_store = __ldlu(reinterpret_cast(input_fwd_ptr + i)); + + TCompute frag_input_fwd[kElementsPerAccess]; + for (int e = 0; e < kElementsPerAccess; e++) + frag_input_fwd[e] = TCompute(frag_load_store_ptr[e]); + + for (int e = 0; e < kElementsPerAccess; e++) { + accum[k] += static_cast(frag_src[e] * frag_input_fwd[e]); + } + } + } + } + } + + if (hasProb) { + for (int k = 0; k < topKTile; k++) { + if (k == topK) break; + // Warp-level reduction + for (int mask = 16; mask > 0; mask /= 2) { + accum[k] = accum[k] + __shfl_xor_sync(0xffffffff, accum[k], mask, 32); + } + } + + if (tid == 0) { + for (int k = 0; k < topKTile; k++) { + if (k == topK) break; + prob_grad[source_token * topK + k] = accum[k]; + } + } + } +} + +template +void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, int *row_id_map, + const float *prob, float *prob_grad, const T *input_fwd, + const int num_rows, const int topK, const int num_cols, + const int num_out_tokens, cudaStream_t stream) { + using TCompute = typename std::conditional<(std::is_same::value || + std::is_same::value), + half, T>::type; + + static constexpr int kElementsPerAccess = 16 / sizeof(T); + + if (input_fwd == nullptr) { + // moe_permute_fwd + + int threads = 64; + int blocks = (num_rows * topK + threads - 1) / threads; + + moe_permute_row_map<<>>(sorted_row_id, row_id_map, num_rows, topK, + num_out_tokens); + + blocks = num_rows; + threads = std::min(num_cols / kElementsPerAccess, 1024); + moe_permute_kernel<<>>( + input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); + } else { + // moe_unpermute_bwd + + int threads = 32; + int blocks = num_rows; + + if (prob == nullptr) { + // moe_unpermute_bwd without probs + + moe_permute_kernel<<>>( + input, input_fwd, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); + } else { + // moe_unpermute_bwd with probs + + size_t smem_bytes = topK * sizeof(TCompute); + + if (topK <= 8) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 16) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 32) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 64) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 128) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else { + NVTE_ERROR("topK cannot exceed 128."); + } + } + } +} + +template +void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const float *prob, + const int num_rows, const int topK, const int num_cols, + cudaStream_t stream) { + using TCompute = typename std::conditional<(std::is_same::value || + std::is_same::value), + half, T>::type; + + static constexpr int kElementsPerAccess = 16 / sizeof(T); + + int blocks = num_rows; + int threads = std::min(num_cols / kElementsPerAccess, 1024); + size_t smem_bytes = topK * sizeof(TCompute); + + if (prob == nullptr) { + // moe_permute_bwd + // moe_unpermute_fwd without probs + + moe_unpermute_kernel<<>>( + input, output, row_id_map, nullptr, num_rows, topK, num_cols); + } else { + // moe_unpermute_fwd with probs + + moe_unpermute_kernel<<>>( + input, output, row_id_map, prob, num_rows, topK, num_cols); + } +} + +void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor sorted_row_id, + NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad, + const NVTETensor input_fwd, const int num_rows, const int topK, + const int num_cols, const int num_out_tokens, cudaStream_t stream) { + NVTE_API_CALL(nvte_permute); + + const transformer_engine::Tensor *input_cu = + reinterpret_cast(input); + const transformer_engine::Tensor *output_cu = + reinterpret_cast(output); + const transformer_engine::Tensor *sorted_row_id_cu = + reinterpret_cast(sorted_row_id); + const transformer_engine::Tensor *row_id_map_cu = + reinterpret_cast(row_id_map); + const transformer_engine::Tensor *prob_cu = + reinterpret_cast(prob); + const transformer_engine::Tensor *prob_grad_cu = + reinterpret_cast(prob_grad); + const transformer_engine::Tensor *input_fwd_cu = + reinterpret_cast(input_fwd); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + input_cu->data.dtype, T, + nvte_permute_launcher(reinterpret_cast(input_cu->data.dptr), + reinterpret_cast(output_cu->data.dptr), + reinterpret_cast(sorted_row_id_cu->data.dptr), + reinterpret_cast(row_id_map_cu->data.dptr), + reinterpret_cast(prob_cu->data.dptr), + reinterpret_cast(prob_grad_cu->data.dptr), + reinterpret_cast(input_fwd_cu->data.dptr), num_rows, topK, + num_cols, num_out_tokens, stream);); +} + +void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map, + const NVTETensor prob, const int num_rows, const int topK, const int num_cols, + cudaStream_t stream) { + NVTE_API_CALL(nvte_unpermute); + + const transformer_engine::Tensor *input_cu = + reinterpret_cast(input); + const transformer_engine::Tensor *output_cu = + reinterpret_cast(output); + const transformer_engine::Tensor *row_id_map_cu = + reinterpret_cast(row_id_map); + const transformer_engine::Tensor *prob_cu = + reinterpret_cast(prob); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + input_cu->data.dtype, T, + nvte_unpermute_launcher(reinterpret_cast(input_cu->data.dptr), + reinterpret_cast(output_cu->data.dptr), + reinterpret_cast(row_id_map_cu->data.dptr), + reinterpret_cast(prob_cu->data.dptr), num_rows, topK, + num_cols, stream);); +} diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 20b6f79da6..1c755491b0 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -44,6 +44,7 @@ def _load_library(): from transformer_engine.pytorch.attention import InferenceParams from transformer_engine.pytorch.attention import MultiheadAttention from transformer_engine.pytorch.transformer import TransformerLayer +from transformer_engine.pytorch.permutation import moe_permute, moe_unpermute from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.fp8 import fp8_model_init from transformer_engine.pytorch.graph import make_graphed_callables diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index aac693a430..7fb9953f94 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 05e4e97112..1a6f5f157e 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -10,6 +10,26 @@ #include "common.h" #include "common/common.h" +/*************************************************************************************************** + * Permutation + **************************************************************************************************/ + +std::tuple> moe_permute_fwd( + at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices, + int64_t num_out_tokens, std::vector workspace, int64_t max_expanded_token_num); + +at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, + int64_t topK); + +at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, + int64_t topK); + +std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd, + const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob); + /*************************************************************************************************** * Attention **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu new file mode 100644 index 0000000000..0c9bed45e0 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -0,0 +1,170 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include "extensions.h" + +std::tuple> moe_permute_fwd( + at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices, + int64_t num_out_tokens, std::vector workspace, int64_t max_expanded_token_num) { + const int num_tokens = input.size(0); + int num_cols = input.size(1); + const int topK = indices.size(1); + + // Initialize the workspace on the first run + if (workspace.empty()) { + auto options = + torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false); + + at::Tensor sorted_indices = torch::empty(max_expanded_token_num, options); + at::Tensor row_id = torch::range(0, max_expanded_token_num - 1, 1, options); + at::Tensor sorted_row_id = + torch::empty(max_expanded_token_num, + torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + + size_t temp_storage_bytes = 0; + int *temp_ptr = nullptr; + cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_ptr, temp_ptr, temp_ptr, + temp_ptr, max_expanded_token_num); + at::Tensor temp_storage = torch::empty( + temp_storage_bytes, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); + + workspace.push_back(sorted_indices); + workspace.push_back(row_id); + workspace.push_back(sorted_row_id); + workspace.push_back(temp_storage); + } + + int *indices_ptr = reinterpret_cast(getDataPtr(indices, 0)); + int *sorted_indices_ptr = reinterpret_cast(getDataPtr(workspace[0], 0)); + int *row_id_ptr = reinterpret_cast(getDataPtr(workspace[1], 0)); + int *sorted_row_id_ptr = reinterpret_cast(getDataPtr(workspace[2], 0)); + + void *d_temp_storage = getDataPtr(workspace[3], 0); + size_t temp_storage_bytes = std::numeric_limits::max(); + + cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, indices_ptr, + sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr, + num_tokens * topK); + + // Activations type + at::ScalarType _st; + if (dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2) + _st = at::ScalarType::Byte; + else + _st = input.scalar_type(); + + // Output buffer alloc + num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK; + at::Tensor permuted_output = torch::empty( + {num_out_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + at::Tensor row_id_map = torch::empty( + {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto input_cu = makeTransformerEngineTensor( + input.data_ptr(), {static_cast(input.size(0)), static_cast(num_cols)}, dtype); + auto permuted_output_cu = makeTransformerEngineTensor( + permuted_output.data_ptr(), + {static_cast(permuted_output.size(0)), static_cast(num_cols)}, dtype); + auto sorted_row_id_cu = + makeTransformerEngineTensor(sorted_row_id_ptr, {static_cast(num_tokens * topK)}, + transformer_engine::DType::kInt32); + auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); + + nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(), + row_id_map_cu.data(), transformer_engine::TensorWrapper().data(), + transformer_engine::TensorWrapper().data(), + transformer_engine::TensorWrapper().data(), num_tokens, topK, num_cols, + num_out_tokens, stream); + + return std::make_tuple(permuted_output, row_id_map, workspace); +} + +at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, + int64_t topK) { + return moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, topK); +} + +at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, + int64_t topK) { + int num_cols = input.size(1); + + // Activations type + at::ScalarType _st; + if (dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2) + _st = at::ScalarType::Byte; + else + _st = input.scalar_type(); + + // Output buffer alloc + at::Tensor unpermuted_output = torch::empty( + {num_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto input_cu = makeTransformerEngineTensor( + input.data_ptr(), {static_cast(input.size(0)), static_cast(num_cols)}, dtype); + auto unpermuted_output_cu = makeTransformerEngineTensor( + unpermuted_output.data_ptr(), + {static_cast(unpermuted_output.size(0)), static_cast(num_cols)}, dtype); + auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); + auto prob_cu = makeTransformerEngineTensor(prob); + + nvte_unpermute(input_cu.data(), unpermuted_output_cu.data(), row_id_map_cu.data(), prob_cu.data(), + num_tokens, topK, num_cols, stream); + + return unpermuted_output; +} + +std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd, + const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob) { + const int topK = (prob.numel() > 0) ? prob.size(1) : 1; + const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0); + int num_cols = input_bwd.size(1); + + // Activations type + at::ScalarType _st; + if (dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2) + _st = at::ScalarType::Byte; + else + _st = input_bwd.scalar_type(); + + // Output buffer alloc + at::Tensor act_grad = torch::empty({input_fwd.size(0), num_cols}, + torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + at::Tensor prob_grad = torch::empty( + {num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto input_bwd_cu = makeTransformerEngineTensor( + input_bwd.data_ptr(), {static_cast(input_bwd.size(0)), static_cast(num_cols)}, + dtype); + auto act_grad_cu = makeTransformerEngineTensor( + act_grad.data_ptr(), {static_cast(act_grad.size(0)), static_cast(num_cols)}, + dtype); + auto input_fwd_cu = makeTransformerEngineTensor( + input_fwd.data_ptr(), {static_cast(input_fwd.size(0)), static_cast(num_cols)}, + dtype); + auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); + auto prob_cu = makeTransformerEngineTensor(prob); + auto prob_grad_cu = makeTransformerEngineTensor(prob_grad); + + nvte_permute(input_bwd_cu.data(), act_grad_cu.data(), transformer_engine::TensorWrapper().data(), + row_id_map_cu.data(), prob_cu.data(), prob_grad_cu.data(), input_fwd_cu.data(), + num_tokens, topK, num_cols, 0, stream); + + return std::make_tuple(act_grad, prob_grad); +} diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 11b47ccdec..f903a1c35b 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -10,6 +10,12 @@ #include "../extensions.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // Permutation functions + m.def("moe_permute_fwd", moe_permute_fwd); + m.def("moe_permute_bwd", moe_permute_bwd); + m.def("moe_unpermute_fwd", moe_unpermute_fwd); + m.def("moe_unpermute_bwd", moe_unpermute_bwd); + // Softmax functions m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD", py::call_guard()); diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py new file mode 100644 index 0000000000..0c098830a9 --- /dev/null +++ b/transformer_engine/pytorch/permutation.py @@ -0,0 +1,289 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Linear API""" +import warnings +from typing import Tuple +import torch + +import transformer_engine_torch as tex +from transformer_engine.pytorch.float8_tensor import Float8Tensor + + +__all__ = [ + "moe_permute", + "moe_unpermute", +] + + +class _moe_permute(torch.autograd.Function): + """functional Permute""" + + workspace = None + max_expanded_token_num = 0 + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + dtype: tex.DType, + indices: torch.Tensor, + num_out_tokens: int, + max_token_num: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Empty input check + if not inp.numel(): + return inp, None + + # Device check + assert inp.is_cuda, "TransformerEngine needs CUDA." + assert indices.is_cuda, "TransformerEngine needs CUDA." + # Shape check + assert inp.size(0) == indices.size(0), "Permute not possible" + + # Data type check + fp8 = False + if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: + fp8 = True + if fp8: + assert isinstance( + inp, Float8Tensor + ), "Input must be in Float8Tensor type for FP8 moe_permute." + fp8_dtype = inp._fp8_dtype + fp8_scale_inv = inp._scale_inv + inp = inp._data + if indices.dtype != torch.int32: + warnings.warn( + f"The data type of the input `indices` of Permute is {indices.dtype}! " + "The recommended type is torch.int32." + ) + indices = indices.to(torch.int32) + + topK = indices.size(1) + + input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK + if _moe_permute.max_expanded_token_num < input_max_expanded_token_num: + _moe_permute.max_expanded_token_num = input_max_expanded_token_num + _moe_permute.workspace = [] + + permuted_act, row_id_map, _moe_permute.workspace = tex.moe_permute_fwd( + inp, + dtype, + indices, + num_out_tokens, + _moe_permute.workspace, + _moe_permute.max_expanded_token_num, + ) + + if fp8: + permuted_act = Float8Tensor( + data=permuted_act, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + ) + + ctx.row_id_map = row_id_map + ctx.num_tokens = indices.size(0) + ctx.topK = indices.size(1) + ctx.dtype = dtype + ctx.fp8 = fp8 + return permuted_act, row_id_map + + @staticmethod + def backward( + ctx, + permuted_act_grad: torch.Tensor, + _, + ) -> Tuple[torch.Tensor, ...]: + # Empty input check + if not permuted_act_grad.numel(): + return permuted_act_grad, None, None, None + + if not permuted_act_grad.is_contiguous(): + permuted_act_grad = permuted_act_grad.contiguous() + + fp8 = ctx.fp8 + if fp8: + assert isinstance( + permuted_act_grad, Float8Tensor + ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute." + fp8_dtype = permuted_act_grad._fp8_dtype + fp8_scale_inv = permuted_act_grad._scale_inv + permuted_act_grad = permuted_act_grad._data + + row_id_map = ctx.row_id_map + num_tokens = ctx.num_tokens + topK = ctx.topK + + act_grad = None + if ctx.needs_input_grad[0]: + act_grad = tex.moe_permute_bwd( + permuted_act_grad, ctx.dtype, row_id_map, torch.empty(0), num_tokens, topK + ) + if fp8: + act_grad = Float8Tensor( + data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv * topK + ) + + return act_grad, None, None, None, None + + +class _moe_unpermute(torch.autograd.Function): + """functional Unpermute""" + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + dtype: tex.DType, + row_id_map: torch.Tensor, + probs: torch.Tensor, + ) -> torch.Tensor: + # Empty input check + if not inp.numel(): + ctx.probs = probs + return inp + + # None probs check + if probs is not None: + assert probs.is_cuda, "TransformerEngine needs CUDA." + + if probs.dtype != torch.float32: + warnings.warn( + f"The data type of the input `probs` of Unpermute is {probs.dtype}! " + "The recommended type is torch.float32." + ) + probs = probs.to(torch.float32) + + num_tokens = probs.size(0) + topK = probs.size(1) + else: + num_tokens = row_id_map.size(0) + topK = 1 + probs = torch.empty(0) + + # Device check + assert inp.is_cuda, "TransformerEngine needs CUDA." + assert row_id_map.is_cuda, "TransformerEngine needs CUDA." + + # Data type check + fp8 = False + if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: + fp8 = True + if fp8: + assert isinstance( + inp, Float8Tensor + ), "Input must be in Float8Tensor type for FP8 moe_unpermute." + fp8_dtype = inp._fp8_dtype + fp8_scale_inv = inp._scale_inv + inp = inp._data + if row_id_map.dtype != torch.int32: + warnings.warn( + f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! " + "The recommended type is torch.int32." + ) + row_id_map = row_id_map.to(torch.int32) + + unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK) + + if fp8: + unpermuted_output = Float8Tensor( + data=unpermuted_output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + ) + + ctx.dtype = dtype + ctx.save_for_backward(inp, row_id_map, probs) + ctx.fp8 = fp8 + return unpermuted_output + + @staticmethod + def backward( + ctx, + unpermuted_act_grad: torch.Tensor, + ) -> Tuple[torch.Tensor, None, torch.Tensor]: + # Empty input check + if not unpermuted_act_grad.numel(): + return unpermuted_act_grad, None, ctx.probs + + if not unpermuted_act_grad.is_contiguous(): + unpermuted_act_grad = unpermuted_act_grad.contiguous() + + fp8 = ctx.fp8 + if fp8: + assert isinstance( + unpermuted_act_grad, Float8Tensor + ), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute." + fp8_dtype = unpermuted_act_grad._fp8_dtype + fp8_scale_inv = unpermuted_act_grad._scale_inv + unpermuted_act_grad = unpermuted_act_grad._data + + inp, row_id_map, probs = ctx.saved_tensors + + act_grad = None + if ctx.needs_input_grad[0]: + act_grad, prob_grad = tex.moe_unpermute_bwd( + unpermuted_act_grad, inp, ctx.dtype, row_id_map, probs + ) + if fp8: + act_grad = Float8Tensor( + data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + ) + if not ctx.needs_input_grad[3]: + prob_grad = None + + return act_grad, None, None, prob_grad + + +def moe_permute( + inp: torch.Tensor, + dtype: tex.DType, + indices: torch.Tensor, + num_out_tokens: int = -1, + max_token_num: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Permute the tokens based on the indices. Token with the same index will be grouped together. + + Parameters + ---------- + inp: torch.Tensor + Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. + dtype: tex.DType + Data type of the input tensor. + indices: torch.Tensor + The token to expert indices tensor of shape [num_tokens, topK] and dtype 'int32'. + num_out_tokens: int, default = -1 + The effective output token count, representing the number of tokens not dropped. + By default, set to '-1', meaning no tokens are dropped. + max_token_num: int, default = -1 + The maximum number of tokens, used for workspace allocation. + By default, set to '-1', meaning the calculation of the size of workspace is + automatically taken over by the operator. + """ + return _moe_permute.apply(inp, dtype, indices, num_out_tokens, max_token_num) + + +def moe_unpermute( + inp: torch.Tensor, + dtype: tex.DType, + row_id_map: torch.Tensor, + probs: torch.Tensor = None, +) -> torch.Tensor: + """ + Unpermute a tensor with permuted tokens, and optionally merge the tokens with their + corresponding probabilities. + + Parameters + ---------- + inp: torch.Tensor + Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted. + dtype: tex.DType + Data type of the input tensor. + row_id_map: torch.Tensor + The tensor of a mapping table for sorted indices used to unpermute the tokens, + which is the second output tensor of `Permute`. + probs: torch.Tensor + The tensor of probabilities corresponding to the permuted tokens. If provided, + the unpermuted tokens will be merged with their respective probabilities. + By default, set to an empty tensor, which means that the tokens are directly merged by accumulation. + """ + return _moe_unpermute.apply(inp, dtype, row_id_map, probs) From 309c6d499385c1181e3af7a3883c202b74248a37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Bastien?= Date: Fri, 23 Aug 2024 10:30:14 -0400 Subject: [PATCH 11/50] Jax example cleanup and replace pjit with jit. (#1107) * Use jit instead of pjit --------- Signed-off-by: Frederic Bastien Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> --- examples/jax/encoder/README.md | 10 +-- .../encoder/test_model_parallel_encoder.py | 85 ++++++++++-------- examples/jax/encoder/test_multigpu_encoder.py | 87 ++++++++++-------- .../encoder/test_multiprocessing_encoder.py | 89 +++++++++++-------- 4 files changed, 153 insertions(+), 118 deletions(-) diff --git a/examples/jax/encoder/README.md b/examples/jax/encoder/README.md index f8e150eac6..575f7be6e0 100644 --- a/examples/jax/encoder/README.md +++ b/examples/jax/encoder/README.md @@ -1,6 +1,6 @@ # Basic Transformer Encoder Example with Optional FP8 # -This example uses Transformer Encoder to demonstrate the Transformer Engine usage. And more focus on scaling up training on multiple GPUs. Highly recommend studying the [MNIST example of the Transformer Engine](/examples/jax/mnist) before reading this example. The Transformer Engine is built on top of [Flax](https://github.com/google/flax). Thus, examples use `pjit` to set up multiple GPU training. The basic pjit usage can be referred to [Scale up Flax Modules on multiple devices with pjit](https://flax.readthedocs.io/en/latest/guides/flax_on_pjit.html). +This example uses Transformer Encoder to demonstrate the Transformer Engine usage. And more focus on scaling up training on multiple GPUs. Highly recommend studying the [MNIST example of the Transformer Engine](/examples/jax/mnist) before reading this example. The Transformer Engine is built on top of [Flax](https://github.com/google/flax). Thus, examples use `jit` `in `in_shardings` and `out_shardings` parameters to set up multiple GPU training. The basic parallel jit usage can be referred to [Scale up Flax Modules on multiple devices](https://flax.readthedocs.io/en/latest/guides/flax_on_pjit.html). ## Single GPU ## @@ -31,11 +31,11 @@ python test_single_gpu_encoder.py --use-fp8 4. The next is to create sharding rules, mapping the device axis to the logical axis. The `te.extend_logical_axis_rules` under fp8_autocast will return a list of pairs of the mapping, such as `(('batch', 'data'), ...)`. The first is the logical axis and second is the device axis. -5. Refer structure of `abs_var_collect['params']` and `abs_var_collect['params_axes']` to set up `PartitionSpec` for pjit. All logical axes should be replaced by device axes. If the value of PartitionSpec is None, that means no sharding, broadcasting the data to every device. Note that the `params_axes` attribute is provided by Transformer Engine. The Flax's module doesn't have it, such as `nn.Embed`. For nn.Embed, assigning an empty PartitionSpec is fine because each device has its own embedding layer in DP mode. The `get_params_pspec` routine is used for this purpose. Because each device has a complete model in DP mode, all values of PartitionSpec in params_pspec should be None. This will be different in the model parallelism example. +5. Refer structure of `abs_var_collect['params']` and `abs_var_collect['params_axes']` to set up `PartitionSpec` for parallel jit. All logical axes should be replaced by device axes. If the value of PartitionSpec is None, that means no sharding, broadcasting the data to every device. Note that the `params_axes` attribute is provided by Transformer Engine. The Flax's module doesn't have it, such as `nn.Embed`. For nn.Embed, assigning an empty PartitionSpec is fine because each device has its own embedding layer in DP mode. The `get_params_pspec` routine is used for this purpose. Because each device has a complete model in DP mode, all values of PartitionSpec in params_pspec should be None. This will be different in the model parallelism example. -6. Fill in `params_pspec` and `encoder.init` to pjit to get a compiled function, `pjit_encoder_init `, and use it to initialize the model, so JAX now can know how to do the sharding. +6. Fill in `params_sharding` and `encoder.init` to jit to get a compiled function, `jit_encoder_init `, and use it to initialize the model, so JAX now can know how to do the sharding. -7. The `train_step` and `eval_step` also need to be compiled by pjit. Thus, every input and output argument has to be set up `PartitionSpec` if the argument contains a tensor. For instance, the `input_pspec` is `PartitionSpec('data', None)` because the input shape is (batch size, sequence length). Then, the rest of the workflow is similar to the previous example. +7. The `train_step` and `eval_step` also need to be compiled by jit. Thus, every input and output argument has to be set up `PartitionSpec` if the argument contains a tensor. For instance, the `input_pspec` is `PartitionSpec('data', None)` because the input shape is (batch size, sequence length). Then, the rest of the workflow is similar to the previous example. 8. Use `CUDA_VISIBLE_DEVICES` to control the number of GPUs used. For example, if the system has 8 GPUs but only 4 GPUs need to be used, then: ```sh @@ -84,7 +84,7 @@ python test_model_parallel_encoder.py --use-fp8 1. This example inherits previous model parallelism example, but uses multiprocessing instead of single-program multiple-data (SPMD). It uses 1 GPU per process. -2. The benefit of multiprocessing is to setup hardware affinity for GPUs, such as NUMA binding. It may help improve performance and stability. Please refer to [Best Practices When Benchmarking CUDA Applications](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2019-s9956/) for more details. +2. There is two main benefits of multiprocessing: support multi-node and to setup hardware affinity for GPUs, such as NUMA binding. Affinity may help improve performance and stability. Please refer to [Best Practices When Benchmarking CUDA Applications](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2019-s9956/) for more details. 3. The quick way to check system topology is to use `nvidia-smi`, for example: ```sh diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 646d6e0a12..25d744887e 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -17,7 +17,7 @@ from flax.linen import partitioning as nn_partitioning from flax.training import train_state from jax.experimental import mesh_utils -from jax.experimental.pjit import pjit +from jax.sharding import PartitionSpec, NamedSharding import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax @@ -223,32 +223,36 @@ def check_fp8(state, var_collect, inputs, masks, labels): ) -def get_params_pspec(sharding_rules, abs_var_collect): - """Refer params to create params partition spec""" - rules_dict = {} - for key, value in sharding_rules: - rules_dict[key] = value +def get_params_sharding(sharding_rules, abs_var_collect, mesh): + """Refer params to create params sharding""" + rules_dict = dict(sharding_rules) def to_device_axis(logical_axis): partitions = [rules_dict[key] for key in logical_axis] - return jax.sharding.PartitionSpec(*partitions) + return NamedSharding(mesh, PartitionSpec(*partitions)) params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {}) - params_axes_pspec = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes)) - params_axes_pspec = flax.core.unfreeze(params_axes_pspec) - params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY]) - params_pspec = {**params_pspec, **params_axes_pspec} - return params_pspec + params_axes_sharding = jax.tree_util.tree_map( + to_device_axis, nn_partitioning.get_axis_names(params_axes) + ) + params_axes_sharding = flax.core.unfreeze(params_axes_sharding) + params_sharding = jax.tree_util.tree_map( + lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY] + ) + params_sharding = {**params_sharding, **params_axes_sharding} + return params_sharding -def get_state_pspec(state, params_pspec): - """Refer params_pspec to create state partition spec""" +def get_state_sharding(state, params_sharding): + """Refer params_sharding to create state sharding""" def replace_params(x): - return params_pspec if isinstance(x, dict) else None + return params_sharding if isinstance(x, dict) else None - state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, dict)) - return state_pspec + state_sharding = jax.tree_util.tree_map( + replace_params, state, is_leaf=lambda x: isinstance(x, dict) + ) + return state_sharding def train_and_evaluate(args): @@ -270,7 +274,9 @@ def train_and_evaluate(args): ), f"Test batch size needs to be multiple of {num_gpu_dp}" device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) - with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)): + with jax.sharding.Mesh( + devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS) + ) as mesh: rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) @@ -291,34 +297,39 @@ def train_and_evaluate(args): customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules - params_pspec = get_params_pspec(sharding_rules, abs_var_collect) - inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) - masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) + params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh) + inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None)) + masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None)) - in_shardings = (None, inputs_pspec, masks_pspec) + in_shardings = (None, inputs_sharding, masks_sharding) out_shardings = { - key: params_pspec if key is PARAMS_KEY else None for key in abs_var_collect + key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect } - pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings) - var_collect = pjit_encoder_init(init_rngs, inputs, masks) + jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings) + var_collect = jit_encoder_init(init_rngs, inputs, masks) optimizer = optax.adamw(args.lr) var_collect, params = flax.core.pop(var_collect, PARAMS_KEY) state = train_state.TrainState.create( apply_fn=encoder.apply, params=params, tx=optimizer ) - state_pspec = get_state_pspec(state, params_pspec) - labels_pspec = jax.sharding.PartitionSpec( - DEVICE_DP_AXIS, + state_sharding = get_state_sharding(state, params_sharding) + labels_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS)) + + in_shardings = ( + state_sharding, + inputs_sharding, + masks_sharding, + labels_sharding, + None, + None, ) + out_shardings = (state_sharding, None, None, None) + jit_train_step = jax.jit(train_step, in_shardings, out_shardings) - in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None) - out_shardings = (state_pspec, None, None, None) - pjit_train_step = pjit(train_step, in_shardings, out_shardings) - - in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None) + in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - pjit_eval_step = pjit(eval_step, in_shardings, out_shardings) + jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -327,7 +338,7 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) rngs = {DROPOUT_KEY: dropout_rng} - pjit_train_step(state, inputs, masks, labels, var_collect, rngs) + jit_train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") return None @@ -337,11 +348,11 @@ def train_and_evaluate(args): rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} state, train_loss, train_accuracy, var_collect = train_epoch( - state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step + state, train_ds, args.batch_size, rngs, var_collect, jit_train_step ) test_loss, test_accuracy = eval_model( - state, test_ds, args.test_batch_size, var_collect, pjit_eval_step + state, test_ds, args.test_batch_size, var_collect, jit_eval_step ) print( diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 005ae50e72..9d08254f4d 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -17,7 +17,7 @@ from flax.linen import partitioning as nn_partitioning from flax.training import train_state from jax.experimental import mesh_utils -from jax.experimental.pjit import pjit +from jax.sharding import PartitionSpec, NamedSharding import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax @@ -202,32 +202,36 @@ def check_fp8(state, var_collect, inputs, masks, labels): ) -def get_params_pspec(sharding_rules, abs_var_collect): - """Refer params to create params partition spec""" - rules_dict = {} - for key, value in sharding_rules: - rules_dict[key] = value +def get_params_sharding(sharding_rules, abs_var_collect, mesh): + """Refer params to create params sharding""" + rules_dict = dict(sharding_rules) def to_device_axis(logical_axis): partitions = [rules_dict[key] for key in logical_axis] - return jax.sharding.PartitionSpec(*partitions) + return NamedSharding(mesh, PartitionSpec(*partitions)) params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {}) - params_axes_pspec = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes)) - params_axes_pspec = flax.core.unfreeze(params_axes_pspec) - params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY]) - params_pspec = {**params_pspec, **params_axes_pspec} - return params_pspec + params_axes_sharding = jax.tree_util.tree_map( + to_device_axis, nn_partitioning.get_axis_names(params_axes) + ) + params_axes_sharding = flax.core.unfreeze(params_axes_sharding) + params_sharding = jax.tree_util.tree_map( + lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY] + ) + params_sharding = {**params_sharding, **params_axes_sharding} + return params_sharding -def get_state_pspec(state, params_pspec): - """Refer params_pspec to create state partition spec""" +def get_state_sharding(state, params_sharding): + """Refer params_sharding to create state sharding""" def replace_params(x): - return params_pspec if isinstance(x, dict) else None + return params_sharding if isinstance(x, dict) else None - state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, dict)) - return state_pspec + state_sharding = jax.tree_util.tree_map( + replace_params, state, is_leaf=lambda x: isinstance(x, dict) + ) + return state_sharding def train_and_evaluate(args): @@ -240,7 +244,7 @@ def train_and_evaluate(args): assert args.test_batch_size % num_gpu == 0, f"Test batch size needs to be multiple of {num_gpu}" device_mesh = mesh_utils.create_device_mesh((num_gpu,)) - with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)): + with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)) as mesh: rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) @@ -260,34 +264,43 @@ def train_and_evaluate(args): abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) sharding_rules = te_flax.extend_logical_axis_rules(tuple()) - params_pspec = get_params_pspec(sharding_rules, abs_var_collect) - inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) - masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) + params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh) + inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None)) + masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None)) - in_shardings = (None, inputs_pspec, masks_pspec) + in_shardings = (None, inputs_sharding, masks_sharding) out_shardings = { - key: params_pspec if key is PARAMS_KEY else None for key in abs_var_collect + key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect } - pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings) - var_collect = pjit_encoder_init(init_rngs, inputs, masks) + jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings) + var_collect = jit_encoder_init(init_rngs, inputs, masks) optimizer = optax.adamw(args.lr) var_collect, params = flax.core.pop(var_collect, PARAMS_KEY) state = train_state.TrainState.create( apply_fn=encoder.apply, params=params, tx=optimizer ) - state_pspec = get_state_pspec(state, params_pspec) - labels_pspec = jax.sharding.PartitionSpec( - DEVICE_DP_AXIS, + state_sharding = get_state_sharding(state, params_sharding) + labels_sharding = NamedSharding( + mesh, + PartitionSpec( + DEVICE_DP_AXIS, + ), ) + in_shardings = ( + state_sharding, + inputs_sharding, + masks_sharding, + labels_sharding, + None, + None, + ) + out_shardings = (state_sharding, None, None, None) + jit_train_step = jax.jit(train_step, in_shardings, out_shardings) - in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None) - out_shardings = (state_pspec, None, None, None) - pjit_train_step = pjit(train_step, in_shardings, out_shardings) - - in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None) + in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - pjit_eval_step = pjit(eval_step, in_shardings, out_shardings) + jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -296,7 +309,7 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) rngs = {DROPOUT_KEY: dropout_rng} - pjit_train_step(state, inputs, masks, labels, var_collect, rngs) + jit_train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") return None @@ -306,11 +319,11 @@ def train_and_evaluate(args): rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} state, train_loss, train_accuracy, var_collect = train_epoch( - state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step + state, train_ds, args.batch_size, rngs, var_collect, jit_train_step ) test_loss, test_accuracy = eval_model( - state, test_ds, args.test_batch_size, var_collect, pjit_eval_step + state, test_ds, args.test_batch_size, var_collect, jit_eval_step ) print( diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 286c064e96..e581dbc3f9 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -19,7 +19,7 @@ from flax.linen import partitioning as nn_partitioning from flax.training import train_state from jax.experimental import mesh_utils -from jax.experimental.pjit import pjit +from jax.sharding import PartitionSpec, NamedSharding import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax @@ -305,32 +305,36 @@ def check_fp8(state, var_collect, inputs, masks, labels): ) -def get_params_pspec(sharding_rules, abs_var_collect): - """Refer params to create params partition spec""" - rules_dict = {} - for key, value in sharding_rules: - rules_dict[key] = value +def get_params_sharding(sharding_rules, abs_var_collect, mesh): + """Refer params to create params sharding""" + rules_dict = dict(sharding_rules) def to_device_axis(logical_axis): partitions = [rules_dict[key] for key in logical_axis] - return jax.sharding.PartitionSpec(*partitions) + return NamedSharding(mesh, jax.sharding.PartitionSpec(*partitions)) params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {}) - params_axes_pspec = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes)) - params_axes_pspec = flax.core.unfreeze(params_axes_pspec) - params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY]) - params_pspec = {**params_pspec, **params_axes_pspec} - return params_pspec + params_axes_sharding = jax.tree_util.tree_map( + to_device_axis, nn_partitioning.get_axis_names(params_axes) + ) + params_axes_sharding = flax.core.unfreeze(params_axes_sharding) + params_sharding = jax.tree_util.tree_map( + lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY] + ) + params_sharding = {**params_sharding, **params_axes_sharding} + return params_sharding -def get_state_pspec(state, params_pspec): - """Refer params_pspec to create state partition spec""" +def get_state_sharding(state, params_sharding): + """Refer params_sharding to create state sharding""" def replace_params(x): - return params_pspec if isinstance(x, dict) else None + return params_sharding if isinstance(x, dict) else None - state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, dict)) - return state_pspec + state_sharding = jax.tree_util.tree_map( + replace_params, state, is_leaf=lambda x: isinstance(x, dict) + ) + return state_sharding def train_and_evaluate(args): @@ -362,7 +366,7 @@ def train_and_evaluate(args): device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) with jax.sharding.Mesh( devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS) - ) as shard_mesh: + ) as mesh: rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) @@ -383,34 +387,41 @@ def train_and_evaluate(args): customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules - params_pspec = get_params_pspec(sharding_rules, abs_var_collect) + params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh) inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) - in_shardings = (None, inputs_pspec, masks_pspec) + inputs_sharding = NamedSharding(mesh, inputs_pspec) + masks_sharding = NamedSharding(mesh, masks_pspec) + in_shardings = (None, inputs_sharding, masks_sharding) out_shardings = { - key: params_pspec if key is PARAMS_KEY else None for key in abs_var_collect + key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect } - pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings) - var_collect = pjit_encoder_init(init_rngs, inputs, masks) + jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings) + var_collect = jit_encoder_init(init_rngs, inputs, masks) optimizer = optax.adamw(args.lr) var_collect, params = flax.core.pop(var_collect, PARAMS_KEY) state = train_state.TrainState.create( apply_fn=encoder.apply, params=params, tx=optimizer ) - state_pspec = get_state_pspec(state, params_pspec) - labels_pspec = jax.sharding.PartitionSpec( - DEVICE_DP_AXIS, + state_sharding = get_state_sharding(state, params_sharding) + labels_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS)) + + in_shardings = ( + state_sharding, + inputs_sharding, + masks_sharding, + labels_sharding, + None, + None, ) + out_shardings = (state_sharding, None, None, None) + jit_train_step = jax.jit(train_step, in_shardings, out_shardings) - in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None) - out_shardings = (state_pspec, None, None, None) - pjit_train_step = pjit(train_step, in_shardings, out_shardings) - - in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None) + in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - pjit_eval_step = pjit(eval_step, in_shardings, out_shardings) + jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -419,7 +430,7 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) rngs = {DROPOUT_KEY: dropout_rng} - pjit_train_step(state, inputs, masks, labels, var_collect, rngs) + jit_train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") else: for epoch in range(1, args.epochs + 1): @@ -433,11 +444,11 @@ def train_and_evaluate(args): args.batch_size, rngs, var_collect, - pjit_train_step, - shard_mesh, + jit_train_step, + mesh, inputs_pspec, masks_pspec, - labels_pspec, + labels_sharding.spec, ) test_loss, test_accuracy = eval_model( @@ -445,11 +456,11 @@ def train_and_evaluate(args): test_ds, args.test_batch_size, var_collect, - pjit_eval_step, - shard_mesh, + jit_eval_step, + mesh, inputs_pspec, masks_pspec, - labels_pspec, + labels_sharding.spec, ) if args.process_id == 0: print( From 2215fa5c7557b66034068816020f9f611019e457 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Date: Fri, 23 Aug 2024 10:06:12 -0700 Subject: [PATCH 12/50] don't put master_param to state if None (#1130) Signed-off-by: Alexandros Koumparoulis --- transformer_engine/pytorch/optimizers/fused_adam.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 322b93a1d8..da0ba3328a 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -219,10 +219,8 @@ def step(self, closure=None, grad_scaler=None): assert ( state["master_param"].shape == p.shape ), "Master weights shape must match model weights shape" - else: - state["master_param"] = None - p_master = state["master_param"] + p_master = state.get("master_param", None) p_grad = p.grad if self.master_weights and p_master is not None and p_master.grad is not None: From 901e5d2b335878aa11f81dab5ddb12fbfad4322a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 23 Aug 2024 15:00:40 -0700 Subject: [PATCH 13/50] Add support for flash-attn 3 (#1019) * WIP: add fa3 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * WIP: clean up Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP: add benchmarks Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * differentiate func/varlen_func Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix parsing keyword for FA3 and remove bshd->thd conversion for flash_attn_func Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * WIP: add FP8 fwd support Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add FA3 FP8 fwd code and test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix assert for FA3 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix FA3 FP8 logic and add tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update FA2 to <=2.6.3 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak unit tests for base/mask Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * set constraints for FA3 for sm90 and causal_bottom_right Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert debug changes in benchmark script Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- benchmarks/attention/benchmark_attention.py | 2 +- setup.py | 2 +- tests/pytorch/fused_attn/test_fused_attn.py | 333 +++++++++---------- transformer_engine/pytorch/attention.py | 339 +++++++++++++++----- transformer_engine/pytorch/fp8.py | 9 + transformer_engine/pytorch/setup.py | 2 +- transformer_engine/pytorch/softmax.py | 3 - 8 files changed, 425 insertions(+), 267 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 2533f5e5c1..98ca4e1941 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 2533f5e5c1877fd76266133c1479ef1643ce3a8b +Subproject commit 98ca4e1941fe3263f128f74f10063a3ea35c7019 diff --git a/benchmarks/attention/benchmark_attention.py b/benchmarks/attention/benchmark_attention.py index bfd7bf8471..cff7c65fbc 100644 --- a/benchmarks/attention/benchmark_attention.py +++ b/benchmarks/attention/benchmark_attention.py @@ -156,7 +156,7 @@ def parse_results(per_cudnn, per_flash, model): df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6 if per_flash > 0: - t_flash_all = df[df["Name"].str.contains("void flash")]["Duration (ns)"].to_numpy() + t_flash_all = df[df["Name"].str.contains("flash")]["Duration (ns)"].to_numpy() t_flash_all = t_flash_all.reshape(-1, per_flash) t_flash_avg = np.average(t_flash_all, axis=0) df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6 diff --git a/setup.py b/setup.py index 6cee4690dc..23ffbd7f2c 100644 --- a/setup.py +++ b/setup.py @@ -92,7 +92,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: # Framework-specific requirements if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "pytorch" in frameworks: - install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.5.8,!=2.0.9,!=2.1.0"]) + install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"]) test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"]) if "jax" in frameworks: install_reqs.extend(["jax", "flax>=0.7.1"]) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 82a3c8576b..b3225a71ef 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -420,6 +420,10 @@ def test_dpa_mla(dtype, model_configs, model): "mask_8_1": ModelConfig( 1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" ), + "mask_9_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"), + "mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"), + "mask_10_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), + "mask_10_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), } @@ -1301,6 +1305,7 @@ def _run_transformer_layer( "fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"), "fp8_13": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"), "fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), + "fp8_15": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), } param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16] @@ -1312,6 +1317,27 @@ def _rmse(a, b): return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum()) +def _error(a, b, name_a, name_b, atol, rtol, rmse_tol): + logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item())) + logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item())) + try: + torch.testing.assert_close(a, b, atol=atol, rtol=rtol) + except Exception as e: + logging.debug(e) + + rmse = _rmse(a, b) + logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse)) + rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) + assert rmse < rmse_tol * rmse_range, ( + name_a + + " vs " + + name_b + + " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( + rmse, rmse_tol * rmse_range, rmse_tol, rmse_range + ) + ) + + @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.") @@ -1320,86 +1346,74 @@ def _rmse(a, b): @pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16) @pytest.mark.parametrize("input_layernorm", [True, False]) @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) -def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd): - os.environ["NVTE_FLASH_ATTN"] = "0" - os.environ["NVTE_FUSED_ATTN"] = "1" +@pytest.mark.parametrize("is_training", [True, False]) +def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, is_training): os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" - global _attention_backends - _attention_backends["backend_selection_requires_update"] = True + os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" config = model_configs_fp8_vs_f16[model] - os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" + global _attention_backends + if not is_training: + os.environ["NVTE_FLASH_ATTN"] = "1" + os.environ["NVTE_FUSED_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") + flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16( + dtype, config, True, qkv_format, input_layernorm, is_training + ) + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16( - dtype, config, True, qkv_format, input_layernorm + dtype, config, True, qkv_format, input_layernorm, is_training ) logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False") fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16( - dtype, config, False, qkv_format, input_layernorm - ) - - tols = dict(atol=5e-1, rtol=5e-1) - rmse_tol = 0.1 - fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16) - fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min( - fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item() + dtype, config, False, qkv_format, input_layernorm, is_training ) + atol = 5e-1 + rtol = 5e-1 + rmse_tol = 0.15 logging.debug("========== {:^25s} ==========".format("forward output")) - logging.debug( - "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format( - fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item() - ) - ) - logging.debug( - "fused_attn_fwd_f16 min {:.6f} max {:.6f}".format( - fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item() + if not is_training: + _error( + flash_attn_fwd_fp8, + fused_attn_fwd_f16, + "flash_attn_fwd_fp8", + "fused_attn_fwd_f16", + atol, + rtol, + rmse_tol, ) + _error( + fused_attn_fwd_fp8, + fused_attn_fwd_f16, + "fused_attn_fwd_fp8", + "fused_attn_fwd_f16", + atol, + rtol, + rmse_tol, ) - logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse)) - try: - torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols) - except Exception as e: - logging.debug(e) - assert ( - fwd_rmse < rmse_tol * fwd_range - ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range - ) - for i in range(len(param_names[:1])): - bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i]) - bwd_range = max( - fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item() - ) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item()) - - logging.debug("========== {:^25s} ==========".format(param_names[i])) - logging.debug( - "fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format( - i, fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item() - ) - ) - logging.debug( - "fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format( - i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item() + if is_training: + for i in range(len(param_names[:1])): + logging.debug("========== {:^25s} ==========".format(param_names[i])) + _error( + fused_attn_bwd_fp8[i], + fused_attn_bwd_f16[i], + f"fused_attn_bwd_fp8[{i}]", + f"fused_attn_bwd_f16[{i}]", + atol, + rtol, + rmse_tol, ) - ) - logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse)) - try: - torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols) - except Exception as e: - logging.debug(e) - - assert ( - bwd_rmse < rmse_tol * bwd_range - ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range - ) -def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm): +def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, is_training): reset_rng_states() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) @@ -1434,6 +1448,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: qkv_weight_interleaved=True, qkv_format=qkv_format, ).to(dtype=dtype, device="cuda") + if not is_training: + mha = mha.eval() seqlens_q = torch.full( [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" @@ -1464,7 +1480,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: tensor_shape = [dim_to_num[j] for j in layout.split("_")] tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda") hidden_states = tensor.view(*tensor.shape[:-2], -1) - hidden_states.requires_grad = True + if is_training: + hidden_states.requires_grad = True tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda") out_grad = tensor.view(*tensor.shape[:-2], -1) @@ -1476,7 +1493,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: core_attention_bias_type=config.attn_bias_type, is_first_microbatch=None, ) - out.backward(out_grad) + if is_training: + out.backward(out_grad) param_names = [] param_names.append("hidden_states.grad") @@ -1487,7 +1505,9 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: param_names.append(name + ".grad") params.append(param) - return out, param_names, tuple(x.grad for x in params) + if is_training: + return out, param_names, tuple(x.grad for x in params) + return out, param_names, tuple(None for x in params) @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.") @@ -1497,7 +1517,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: @pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys()) @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16) @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) -def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd): +@pytest.mark.parametrize("is_training", [True, False]) +def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): config = model_configs_fp8_vs_f16[model] if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: @@ -1505,76 +1526,69 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd): os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" + global _attention_backends - _attention_backends["backend_selection_requires_update"] = True + if not is_training: + os.environ["NVTE_FLASH_ATTN"] = "1" + os.environ["NVTE_FUSED_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True") + flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( + dtype, config, True, qkv_layout, is_training + ) + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + _attention_backends["backend_selection_requires_update"] = True logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True") - fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(dtype, config, True, qkv_layout) + fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( + dtype, config, True, qkv_layout, is_training + ) logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False") - fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(dtype, config, False, qkv_layout) + fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( + dtype, config, False, qkv_layout, is_training + ) - tols = dict(atol=5e-1, rtol=5e-2) + atol = 5e-1 + rtol = 5e-2 rmse_tol = 0.1 bwd_names = ["dq", "dk", "dv"] - fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16) - fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min( - fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item() - ) - logging.debug("========== {:^25s} ==========".format("forward output")) - logging.debug( - "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format( - fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item() + if not is_training: + _error( + flash_attn_fwd_fp8, + fused_attn_fwd_f16, + "flash_attn_fwd_fp8", + "fused_attn_fwd_f16", + atol, + rtol, + rmse_tol, ) + _error( + fused_attn_fwd_fp8, + fused_attn_fwd_f16, + "fused_attn_fwd_fp8", + "fused_attn_fwd_f16", + atol, + rtol, + rmse_tol, ) - logging.debug( - "fused_attn_fwd_f16 min {:.6f} max {:.6f}".format( - fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item() - ) - ) - logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse)) - try: - torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols) - except Exception as e: - logging.debug(e) - - assert ( - fwd_rmse < rmse_tol * fwd_range - ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range - ) - for i, _ in enumerate(fused_attn_bwd_f16): - bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i]) - bwd_range = max( - fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item() - ) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item()) - - logging.debug("========== {:^25s} ==========".format(bwd_names[i])) - logging.debug( - "fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format( - i, fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item() - ) - ) - logging.debug( - "fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format( - i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item() + if is_training: + for i, _ in enumerate(fused_attn_bwd_f16): + logging.debug("========== {:^25s} ==========".format(bwd_names[i])) + _error( + fused_attn_bwd_fp8[i], + fused_attn_bwd_f16[i], + f"fused_attn_bwd_fp8[{i}]", + f"fused_attn_bwd_f16[{i}]", + atol, + rtol, + rmse_tol, ) - ) - logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse)) - try: - torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols) - except Exception as e: - logging.debug(e) - - assert ( - bwd_rmse < rmse_tol * bwd_range - ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range - ) -def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout): +def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training): reset_rng_states() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() @@ -1607,6 +1621,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: attention_type="self", qkv_format=qkv_format, ).to(dtype=dtype, device="cuda") + if not is_training: + dpa = dpa.eval() seqlens_q = torch.full( [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" @@ -1680,9 +1696,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: core_attention_bias_type=config.attn_bias_type, is_first_microbatch=True, ) - out.backward(out_grad) + if is_training: + out.backward(out_grad) - return out, (inp[0].grad, inp[1].grad, inp[2].grad) + if is_training: + return out, (inp[0].grad, inp[1].grad, inp[2].grad) + return out, (None, None, None) model_configs_fp8 = { @@ -1726,58 +1745,26 @@ def test_custom_mha_fp8_vs_f16(dtype, model): fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention") unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention") - tols = dict(atol=5e-1, rtol=5e-1) + atol = 5e-1 + rtol = 5e-1 rmse_tol = 0.1 - fwd_rmse = _rmse(fused_attn_fwd_fp8, unfused_attn_fwd_f16) - fwd_range = max(fused_attn_fwd_fp8.max().item(), unfused_attn_fwd_f16.max().item()) - min( - fused_attn_fwd_fp8.min().item(), unfused_attn_fwd_f16.min().item() - ) - bwd_rmse = _rmse(fused_attn_bwd_fp8, unfused_attn_bwd_f16) - bwd_range = max(fused_attn_bwd_fp8.max().item(), unfused_attn_bwd_f16.max().item()) - min( - fused_attn_bwd_fp8.min().item(), unfused_attn_bwd_f16.min().item() - ) - - logging.debug( - "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format( - fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item() - ) - ) - logging.debug( - "unfused_attn_fwd_f16 min {:.6f} max {:.6f}".format( - unfused_attn_fwd_f16.min().item(), unfused_attn_fwd_f16.max().item() - ) - ) - logging.debug("fused_attn_fwd_fp8 vs unfused_attn_fwd_f16 RMSE: {:.6f}".format(fwd_rmse)) - try: - torch.testing.assert_close(fused_attn_fwd_fp8, unfused_attn_fwd_f16, **tols) - except Exception as e: - logging.debug(e) - - logging.debug( - "fused_attn_bwd_fp8 min {:.6f} max {:.6f}".format( - fused_attn_bwd_fp8.min().item(), fused_attn_bwd_fp8.max().item() - ) - ) - logging.debug( - "unfused_attn_bwd_f16 min {:.6f} max {:.6f}".format( - unfused_attn_bwd_f16.min().item(), unfused_attn_bwd_f16.max().item() - ) - ) - logging.debug("fused_attn_bwd_fp8 vs unfused_attn_bwd_f16 RMSE: {:.6f}".format(bwd_rmse)) - try: - torch.testing.assert_close(fused_attn_bwd_fp8, unfused_attn_bwd_f16, **tols) - except Exception as e: - logging.debug(e) - - assert ( - fwd_rmse < rmse_tol * fwd_range - ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range + _error( + fused_attn_fwd_fp8, + unfused_attn_fwd_f16, + "fused_attn_fwd_fp8", + "unfused_attn_fwd_f16", + atol, + rtol, + rmse_tol, ) - assert ( - bwd_rmse < rmse_tol * bwd_range - ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range + _error( + fused_attn_bwd_fp8, + unfused_attn_bwd_f16, + "fused_attn_bwd_fp8", + "unfused_attn_bwd_f16", + atol, + rtol, + rmse_tol, ) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 6a46d6c3c1..5c7f361c0a 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -6,6 +6,7 @@ import collections from contextlib import nullcontext from importlib.metadata import version as get_pkg_version +from importlib.metadata import PackageNotFoundError import math import os from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -38,7 +39,7 @@ AttnMaskType, FusedAttnBackend, ) -from transformer_engine.pytorch.fp8 import get_fp8_te_dtype +from transformer_engine.pytorch.fp8 import get_fp8_te_dtype, get_fp8_torch_dtype from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.module import LayerNormLinear, Linear from transformer_engine.pytorch.module.base import TransformerEngineBaseModule @@ -75,16 +76,42 @@ _flash_attn_version = PkgVersion(get_pkg_version("flash-attn")) _flash_attn_version_required = PkgVersion("2.0.6") -_flash_attn_max_version = PkgVersion("2.5.8") +_flash_attn_max_version = PkgVersion("2.6.3") _flash_attn_2_plus = _flash_attn_version >= PkgVersion("2") _flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1") _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3") _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4") _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1") _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7") +_flash_attn_3_plus = False +_use_flash_attn_3 = False +try: + _flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper")) + _flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.6.1") +except PackageNotFoundError: + warnings.warn( + "To use flash-attn v3, please use the following commands to install: \n" + """(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n""" + """(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n""" + """(3) mkdir -p $python_path/flashattn_hopper \n""" + """(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py""" + ) +else: + from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3 + from flashattn_hopper.flash_attn_interface import ( + flash_attn_varlen_func as flash_attn_varlen_func_v3, + ) + from flashattn_hopper.flash_attn_interface import ( # pylint: disable=unused-import + _flash_attn_forward as _flash_attn_forward_v3, + ) + from flashattn_hopper.flash_attn_interface import ( # pylint: disable=unused-import + _flash_attn_backward as _flash_attn_backward_v3, + ) + + _use_flash_attn_3 = True if _flash_attn_version >= _flash_attn_version_required: - from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func + from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd @@ -318,6 +345,7 @@ def get_attention_backend( use_fused_attention = False # Filter: Compute capability + global _flash_attn_3_plus, _use_flash_attn_3 if device_compute_capability < (8, 0): if use_flash_attention: logger.debug("Disabling FlashAttention as it requires compute capability sm80+") @@ -325,32 +353,37 @@ def get_attention_backend( if use_fused_attention: logger.debug("Disabling FusedAttention as it requires compute capability sm80+") use_fused_attention = False + if device_compute_capability < (9, 0): + if use_flash_attention and _flash_attn_3_plus: + logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+") + _use_flash_attn_3 = False # Filter: Data type - if use_flash_attention and ( - qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type == Float8Tensor - ): - logger.debug( - "Disabling FlashAttention due to unsupported QKV data type. " - "Supported: qkv_type = torch.Tensor, qkv_dtype = {torch.bfloat16, torch.float16}. " - "Found: qkv_type = %s, qkv_dtype = %s.", - qkv_type, - qkv_dtype, - ) - use_flash_attention = False - if use_fused_attention and (qkv_dtype not in [torch.bfloat16, torch.float16]): - logger.debug( - "Disabling FusedAttention due to unsupported QKV data type. " - "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. " - "Found: qkv_dtype = %s.", - qkv_dtype, - ) - use_fused_attention = False + if qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type not in [ + torch.Tensor, + Float8Tensor, + ]: + if use_flash_attention: + logger.debug( + "Disabling FlashAttention due to unsupported QKV data type. " + "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. " + "Found: qkv_dtype = %s.", + qkv_dtype, + ) + use_flash_attention = False + if use_fused_attention: + logger.debug( + "Disabling FusedAttention due to unsupported QKV data type. " + "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. " + "Found: qkv_dtype = %s.", + qkv_dtype, + ) + use_fused_attention = False # Filter: Execution type if fp8 and fp8_meta["recipe"].fp8_dpa: - if use_flash_attention: - logger.debug("Disabling FlashAttention as it does not support FP8") + if use_flash_attention and is_training: + logger.debug("Disabling FlashAttention as it does not support FP8 training") use_flash_attention = False if use_unfused_attention: logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8") @@ -396,6 +429,12 @@ def get_attention_backend( ) use_flash_attention = False + # Filter: Dropout + if attention_dropout != 0.0 and use_flash_attention: + if _flash_attn_3_plus and _use_flash_attn_3: + logger.debug("Disabling FlashAttention 3 for dropout") + _use_flash_attn_3 = False + # Filter: Context parallelism # qkv_format | attn_mask_type | attn_bias_type | supported backends # ---------------------------------------------------------------------------------------------------- @@ -414,6 +453,14 @@ def get_attention_backend( ) use_unfused_attention = False if context_parallel and use_flash_attention: + if _flash_attn_3_plus and _use_flash_attn_3: + logger.debug("Disabling FlashAttention 3 for context parallelism") + _use_flash_attn_3 = False + if fp8 and fp8_meta["recipe"].fp8_dpa: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with FP8" + ) + use_flash_attention = False if "bottom_right" in attn_mask_type: logger.debug( "Disabling FlashAttention as it does not support context parallelism with" @@ -439,6 +486,7 @@ def get_attention_backend( " bias for THD format" ) use_flash_attention = False + if context_parallel and use_fused_attention: if "bottom_right" in attn_mask_type: logger.debug( @@ -498,6 +546,18 @@ def get_attention_backend( if use_fused_attention: logger.debug("Disabling FusedAttention for arbitrary mask") use_fused_attention = False + if ( + use_flash_attention + and _flash_attn_3_plus + and attn_mask_type in ["causal", "padding_causal"] + and max_seqlen_q != max_seqlen_kv + ): + logger.warning( + "Disabling FlashAttention 3 as it only supports bottom-right-diagonal " + "causal mask since flash-attn 2.1. See " + "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" + ) + _use_flash_attn_3 = False if ( use_flash_attention and _flash_attn_2_1_plus @@ -571,6 +631,15 @@ def get_attention_backend( attn_mask_type, ) use_fused_attention = False + if ( + use_flash_attention + and (window_size[0] != -1 or window_size[1] not in [-1, 0]) + and _flash_attn_3_plus + ): + logger.debug( + "Disabling FlashAttention 3 as it does not support sliding window attention" + ) + _use_flash_attn_3 = False if ( use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]) @@ -590,6 +659,14 @@ def get_attention_backend( # | | bottom_right (converts to a 'post_scale_bias' bias) # UnfusedDotProductAttention | no_bias, pre/post_scale_bias | # | alibi/alibi_slopes | both; converts to a 'post_scale_bias' bias + if use_flash_attention and core_attention_bias_type == "alibi": + if _flash_attn_3_plus and _use_flash_attn_3: + logger.debug("Disabling FlashAttention 3 for ALiBi") + _use_flash_attn_3 = False + if not _flash_attn_2_4_plus: + logger.debug("Disabling FlashAttention for ALiBi") + use_flash_attention = False + if use_flash_attention and ( core_attention_bias_type not in ["no_bias", "alibi"] or core_attention_bias_shape is not None @@ -1071,7 +1148,7 @@ def _get_full_cu_seqlens( return _cu_seqlens_cache[(batch_size, max_seqlen)] -@jit_fuser +@torch.compile def pack_tensor( indices: torch.Tensor, tensor: torch.Tensor, @@ -1082,14 +1159,19 @@ def pack_tensor( padding_indice = torch.zeros( 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device ) - tensor = torch.cat((tensor, padding_indice), dim=0) - indices = indices.repeat(1, tensor.shape[1], tensor.shape[2]) - packed = torch.gather(tensor, 0, indices) + if isinstance(tensor, Float8Tensor): + tensor_data = torch.cat((tensor._data, padding_indice), dim=0) + + packed = Float8Tensor.make_like(tensor, data=torch.gather(tensor_data, 0, indices)) + else: + tensor = torch.cat((tensor, padding_indice), dim=0) + + packed = torch.gather(tensor, 0, indices) return packed -@jit_fuser +@torch.compile def pack_2_tensors( indices: torch.Tensor, t1: torch.Tensor, @@ -1103,7 +1185,7 @@ def pack_2_tensors( return t1_packed, t2_packed -@jit_fuser +@torch.compile def pack_3_tensors( indices: torch.Tensor, t1: torch.Tensor, @@ -1119,7 +1201,7 @@ def pack_3_tensors( return t1_packed, t2_packed, t3_packed -@jit_fuser +@torch.compile def unpack_tensor( indices: torch.Tensor, dim0: int, @@ -1132,12 +1214,16 @@ def unpack_tensor( unpacked = torch.zeros( dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device ) - unpacked.scatter_(0, indices, tensor) - unpacked = unpacked[0:-1, :, :] + if isinstance(tensor, Float8Tensor): + unpacked.scatter_(0, indices, tensor._data) + unpacked = Float8Tensor.make_like(tensor, data=unpacked[0:-1, :, :]) + else: + unpacked.scatter_(0, indices, tensor) + unpacked = unpacked[0:-1, :, :] return unpacked -@jit_fuser +@torch.compile def unpack_2_tensors( indices: torch.Tensor, dim0: int, @@ -1152,7 +1238,7 @@ def unpack_2_tensors( return t1_unpacked, t2_unpacked -@jit_fuser +@torch.compile def unpack_3_tensors( indices: torch.Tensor, dim0: int, @@ -4212,14 +4298,15 @@ def forward( cp_global_ranks: List[int] = None, cp_stream: torch.cuda.Stream = None, cp_comm_type: str = "p2p", + fp8: bool = False, + fp8_meta: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: """flash-attn fprop""" - assert ( - query_layer.dtype in [torch.float16, torch.bfloat16] - and key_layer.dtype in [torch.float16, torch.bfloat16] - and value_layer.dtype in [torch.float16, torch.bfloat16] - ), "FlashAttention currently only supports FP16 and BF16." + assert all( + x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor) + for x in [query_layer, key_layer, value_layer] + ), "FlashAttention only supports FP16 and BF16 data types, or Float8Tensors." assert ( query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda ), "FlashAttention currently only supports CUDA tensors." @@ -4232,24 +4319,36 @@ def forward( qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "sbhd": - # For now just 128, will make it more general in the future - if ( - query_layer.shape[-1] == 128 - and query_layer.shape[0] * query_layer.shape[1] >= 512 - and qkv_layout == "sbh3d" - ): - query_layer, key_layer, value_layer = _PrepareQKVForFA.apply( - query_layer, key_layer, value_layer - ) - else: + if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): + if qkv_format == "sbhd": + # For now just 128, will make it more general in the future + if ( + query_layer.shape[-1] == 128 + and query_layer.shape[0] * query_layer.shape[1] >= 512 + and qkv_layout == "sbh3d" + ): + query_layer, key_layer, value_layer = _PrepareQKVForFA.apply( + query_layer, key_layer, value_layer + ) + else: + query_layer, key_layer, value_layer = [ + x.transpose(0, 1).contiguous() + for x in (query_layer, key_layer, value_layer) + ] + elif qkv_format in ["bshd", "thd"]: query_layer, key_layer, value_layer = [ - x.transpose(0, 1).contiguous() for x in (query_layer, key_layer, value_layer) + x.contiguous() for x in (query_layer, key_layer, value_layer) + ] + else: + if qkv_format == "sbhd": + query_layer._data, key_layer._data, value_layer._data = [ + x.transpose(0, 1).contiguous() + for x in (query_layer._data, key_layer._data, value_layer._data) + ] + elif qkv_format in ["bshd", "thd"]: + query_layer._data, key_layer._data, value_layer._data = [ + x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data) ] - elif qkv_format in ["bshd", "thd"]: - query_layer, key_layer, value_layer = [ - x.contiguous() for x in (query_layer, key_layer, value_layer) - ] batch_size = query_layer.shape[0] @@ -4257,16 +4356,15 @@ def forward( max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1] max_seqlen_q *= cp_size max_seqlen_kv *= cp_size - if not context_parallel: + + if "padding" in attn_mask_type: + assert not context_parallel, "Padding mask not supported with context parallelism!" # [b * s, h, d] query_layer, key_layer, value_layer = [ - x.view(x.shape[0] * x.shape[1], *x.shape[2:]) + x.reshape(x.shape[0] * x.shape[1], *x.shape[2:]) for x in [query_layer, key_layer, value_layer] ] - if "padding" in attn_mask_type: - assert not context_parallel, "Padding mask not supported with context parallelism!" - if self.attention_type == "self": assert ( max_seqlen_q == max_seqlen_kv @@ -4319,7 +4417,9 @@ def forward( seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] max_seqlen_kv = seqlens_kv.max().item() - if context_parallel: + if context_parallel and all( + not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer] + ): assert ( alibi_slopes is None ), "Alibi slope bias addition is not supported with context parallelism." @@ -4366,34 +4466,94 @@ def forward( fa_optional_forward_kwargs["deterministic"] = self.deterministic if _flash_attn_2_5_7_plus: fa_optional_forward_kwargs["block_table"] = None - output = flash_attn_forward_func( - query_layer, - key_layer, - value_layer, - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - self.attention_dropout if self.training else 0.0, - softmax_scale=self.softmax_scale, - causal="causal" in attn_mask_type, - **fa_optional_forward_kwargs, - ) + fa_optional_forward_args_thd = [] + if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type: + func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3 + else: + func = ( + flash_attn_varlen_func + if not _use_flash_attn_3 + else flash_attn_varlen_func_v3 + ) + fa_optional_forward_args_thd.append(cu_seqlens_q) + fa_optional_forward_args_thd.append(cu_seqlens_kv) + fa_optional_forward_args_thd.append(max_seqlen_q) + fa_optional_forward_args_thd.append(max_seqlen_kv) + if _use_flash_attn_3: + if fp8: + fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + activation_dtype = query_layer.dtype + torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True) + if fp8_meta["recipe"].fp8_mha: + assert all( + isinstance(x, Float8Tensor) + for x in [query_layer, key_layer, value_layer] + ), "q/k/v must be Float8Tensors for FP8 MHA." + fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv + query_layer, key_layer, value_layer = ( + x.to(activation_dtype).to(torch_dtype) + for x in [query_layer, key_layer, value_layer] + ) + else: + query_layer, key_layer, value_layer = ( + x.to(torch_dtype) for x in [query_layer, key_layer, value_layer] + ) + output, _ = func( + query_layer, + key_layer, + value_layer, + *fa_optional_forward_args_thd, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + deterministic=self.deterministic, + ) + if fp8 and fp8_meta["recipe"].fp8_mha: + output = cast_to_fp8( + output, + fp8_meta["scaling_fwd"], + META_O, + fp8_dtype_forward, + ) + output = Float8Tensor( + data=output, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_O, + fp8_dtype=fp8_dtype_forward, + dtype=activation_dtype, + ) + else: + output = func( + query_layer, + key_layer, + value_layer, + *fa_optional_forward_args_thd, + self.attention_dropout if self.training else 0.0, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + **fa_optional_forward_kwargs, + ) if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type: output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) if qkv_format == "sbhd": # (bs)hd -> bs(hd) -> sb(hd) - output = ( - output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1).contiguous() - ) + if fp8 and fp8_meta["recipe"].fp8_mha: + output.reshape(batch_size * max_seqlen_q // cp_size, -1).transpose_2d() + output = output.reshape(batch_size, max_seqlen_q // cp_size, -1) + else: + output = ( + output.view(batch_size, max_seqlen_q // cp_size, -1) + .transpose(0, 1) + .contiguous() + ) elif qkv_format == "bshd": # (bs)hd -> bs(hd) - output = output.view(batch_size, max_seqlen_q // cp_size, -1).contiguous() + output = output.reshape(batch_size, max_seqlen_q // cp_size, -1) elif qkv_format == "thd": # thd -> t(hd) - output = output.view(output.shape[0], -1).contiguous() + output = output.reshape(output.shape[0], -1) return output @@ -5897,11 +6057,10 @@ def forward( assert ( fused_attention_backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend ), "No fused attention backend supports this input combination!" - assert ( - (query_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8]) - and (key_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8]) - and (value_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8]) - ), "FusedAttention only supports FP16 and BF16 data types." + assert all( + x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor) + for x in [query_layer, key_layer, value_layer] + ), "FusedAttention only supports FP16 and BF16 data types, or Float8Tensors." assert ( query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda ), "FusedAttention only supports CUDA tensors." @@ -6812,7 +6971,7 @@ def forward( fp8=self.fp8, fp8_meta=self.fp8_meta, ) - global _attention_backends + global _attention_backends, _flash_attn_3_plus, _use_flash_attn_3 if ( _attention_backends["attention_params"] is None or attention_params != _attention_backends["attention_params"] @@ -6820,6 +6979,7 @@ def forward( _attention_backends["attention_params"] = attention_params _attention_backends["backend_selection_requires_update"] = True if _attention_backends["backend_selection_requires_update"]: + _use_flash_attn_3 = _flash_attn_3_plus ( use_flash_attention, use_fused_attention, @@ -6828,7 +6988,10 @@ def forward( _, ) = get_attention_backend(attention_params) if use_flash_attention: - self.logger.info("Running with FlashAttention backend") + self.logger.info( + "Running with FlashAttention backend (version %s)", + _flash_attn_version if not _use_flash_attn_3 else _flash_attn_v3_version, + ) elif use_fused_attention: self.logger.info( "Running with FusedAttention backend (sub-backend %s)", @@ -6867,6 +7030,8 @@ def forward( cp_comm_type=self.cp_comm_type, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, + fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, + fp8_meta=self.fp8_meta, ) if use_fused_attention: diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index e15268b998..76679eb064 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -38,6 +38,15 @@ def get_default_fp8_recipe() -> DelayedScaling: return DelayedScaling() +def get_fp8_torch_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> torch.dtype: + """Get fp8 data type according to recipe and tensor""" + if fp8_recipe.fp8_format == Format.E4M3 or ( + fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor + ): + return torch.float8_e4m3fn + return torch.float8_e5m2fn + + def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType: """Get fp8 data type according to recipe and tensor""" if fp8_recipe.fp8_format == Format.E4M3 or ( diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index e2f15d5d89..034e671150 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -56,7 +56,7 @@ description="Transformer acceleration library - Torch Lib", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, - install_requires=["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"], + install_requires=["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"], tests_require=["numpy", "onnxruntime", "torchvision"], ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): diff --git a/transformer_engine/pytorch/softmax.py b/transformer_engine/pytorch/softmax.py index 4fb8a28857..a632851a76 100644 --- a/transformer_engine/pytorch/softmax.py +++ b/transformer_engine/pytorch/softmax.py @@ -23,9 +23,6 @@ def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor: """Return the causal upper triangular mask for softmax input""" - if sq == 1: - return torch.zeros((1, sk), dtype=torch.bool, device="cuda") - matrix_identifiers = (mask_type, sq, sk) if matrix_identifiers not in _default_causal_mask: diagonal_offset = sk - sq + 1 if "bottom_right" in mask_type else 1 From 4ec66c77752f716188eeb20059d72917946ea6b0 Mon Sep 17 00:00:00 2001 From: hXl3s Date: Sat, 24 Aug 2024 02:01:57 +0200 Subject: [PATCH 14/50] Let user limit number of architectures, to improve build time (#1126) * Limit number of architectures build Signed-off-by: Lukasz Pierscieniewski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Lukasz Pierscieniewski Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- build_tools/pytorch.py | 19 ++++++++++++------- build_tools/utils.py | 7 ++++++- setup.py | 9 +++++---- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index f932f0695e..3725e58c87 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -10,8 +10,9 @@ from .utils import ( all_files_in_dir, - cuda_version, + cuda_archs, cuda_path, + cuda_version, ) @@ -48,8 +49,6 @@ def setup_pytorch_extension( ] nvcc_flags = [ "-O3", - "-gencode", - "arch=compute_70,code=sm_70", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_BFLOAT16_OPERATORS__", @@ -61,6 +60,11 @@ def setup_pytorch_extension( "--use_fast_math", ] + cuda_architectures = cuda_archs() + + if "70" in cuda_architectures: + nvcc_flags.extend(["-gencode", "arch=compute_70,code=sm_70"]) + # Version-dependent CUDA options try: version = cuda_version() @@ -73,13 +77,14 @@ def setup_pytorch_extension( ( "--threads", os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"), - "-gencode", - "arch=compute_80,code=sm_80", - "-gencode", - "arch=compute_90,code=sm_90", ) ) + if "80" in cuda_architectures: + nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"]) + if "90" in cuda_architectures: + nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"]) + # Libraries library_dirs = [] libraries = [] diff --git a/build_tools/utils.py b/build_tools/utils.py index 81b9a896cb..964a445bc4 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -6,12 +6,12 @@ import functools import glob +import importlib import os import re import shutil import subprocess import sys -import importlib from pathlib import Path from subprocess import CalledProcessError from typing import List, Optional, Tuple, Union @@ -188,6 +188,11 @@ def cuda_path() -> Tuple[str, str]: return cuda_home, nvcc_bin +@functools.lru_cache(maxsize=None) +def cuda_archs() -> str: + return os.getenv("NVTE_CUDA_ARCHS", "70;80;89;90") + + def cuda_version() -> Tuple[int, ...]: """CUDA Toolkit version as a (major, minor) tuple.""" # Query NVCC for version info diff --git a/setup.py b/setup.py index 23ffbd7f2c..4e5359e9c0 100644 --- a/setup.py +++ b/setup.py @@ -13,17 +13,17 @@ from wheel.bdist_wheel import bdist_wheel from build_tools.build_ext import CMakeExtension, get_build_ext +from build_tools.te_version import te_version from build_tools.utils import ( + cuda_archs, found_cmake, found_ninja, found_pybind11, - remove_dups, get_frameworks, install_and_import, + remove_dups, uninstall_te_fw_packages, ) -from build_tools.te_version import te_version - frameworks = get_frameworks() current_file_path = Path(__file__).parent.resolve() @@ -59,10 +59,11 @@ def setup_common_extension() -> CMakeExtension: """Setup CMake extension for common library""" # Project directory root root_path = Path(__file__).resolve().parent + return CMakeExtension( name="transformer_engine", cmake_path=root_path / Path("transformer_engine/common"), - cmake_flags=[], + cmake_flags=["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())], ) From 7fc50f489b8184fbd93efd4e48140ad0264e362b Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 23 Aug 2024 21:13:43 -0700 Subject: [PATCH 15/50] Bump cudnn-frontend version to 1.6.1 (#1133) bump cudnn-frontend version Signed-off-by: Kirthi Shankar Sivamani --- 3rdparty/cudnn-frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 98ca4e1941..2533f5e5c1 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 98ca4e1941fe3263f128f74f10063a3ea35c7019 +Subproject commit 2533f5e5c1877fd76266133c1479ef1643ce3a8b From 4ddb0a7bea787294282d0fe0715adf5ea4a39779 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 27 Aug 2024 06:50:06 -0700 Subject: [PATCH 16/50] Hide non-necessary symbols from shared object (#1136) Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/CMakeLists.txt | 4 ++++ transformer_engine/common/libtransformer_engine.version | 4 ++++ 2 files changed, 8 insertions(+) create mode 100644 transformer_engine/common/libtransformer_engine.version diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 06bfec49b4..a6fd6815c3 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -15,6 +15,10 @@ if (CMAKE_BUILD_TYPE STREQUAL "Debug") set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G") endif() +# Hide non-necessary symbols in shared object. +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version") + # Transformer Engine library project(transformer_engine LANGUAGES CUDA CXX) diff --git a/transformer_engine/common/libtransformer_engine.version b/transformer_engine/common/libtransformer_engine.version new file mode 100644 index 0000000000..0683ec01ea --- /dev/null +++ b/transformer_engine/common/libtransformer_engine.version @@ -0,0 +1,4 @@ +{ + global: *nvte*; *transformer_engine*; + local: *; +}; From 8ddac3df41c0304ac9efe0c1e9b23c93326979a8 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Fri, 30 Aug 2024 06:06:53 +0800 Subject: [PATCH 17/50] [PyTorch] Remove `dtype` from args of permutation (#1145) * remove dtype from args * update docs with permutation ops --------- Signed-off-by: Xin Yao --- docs/api/pytorch.rst | 4 ++ qa/L0_pytorch_unittest/test.sh | 1 + tests/pytorch/test_permutation.py | 10 ++- transformer_engine/pytorch/permutation.py | 83 +++++++++-------------- 4 files changed, 41 insertions(+), 57 deletions(-) diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index c9504c20af..a210019dc1 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -44,3 +44,7 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.make_graphed_callables .. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context + +.. autoapifunction:: transformer_engine.pytorch.moe_permute + +.. autoapifunction:: transformer_engine.pytorch.moe_unpermute diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index e6ccf3b82f..b69aed6648 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -23,3 +23,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops_distributed.py +pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index 99bd706b45..ed25b96955 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -220,9 +220,7 @@ def _test_permutation( te_permute_fwd_input.requires_grad_(True) te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach() - te_permute_output, row_id_map = te_permute( - te_permute_fwd_input, te_dtype, indices, num_out_tokens - ) + te_permute_output, row_id_map = te_permute(te_permute_fwd_input, indices, num_out_tokens) te_permute_output.backward(te_permute_bwd_input, retain_graph=True) te_probs = None @@ -233,7 +231,7 @@ def _test_permutation( te_unpermute_fwd_input.requires_grad_(True) te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() - te_unpermute_output = te_unpermute(te_unpermute_fwd_input, te_dtype, row_id_map, te_probs) + te_unpermute_output = te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs) te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True) ################################################################################################################################### @@ -305,7 +303,7 @@ def backward_wrapper( lambda: pytorch_permute(pytorch_permute_fwd_input, indices, num_out_tokens) ) t2 = perf_test_cuda_kernel( - lambda: te_permute(te_permute_fwd_input, te_dtype, indices, num_out_tokens) + lambda: te_permute(te_permute_fwd_input, indices, num_out_tokens) ) print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") @@ -333,7 +331,7 @@ def backward_wrapper( lambda: pytorch_unpermute(pytorch_unpermute_fwd_input, sorted_indices, probs=probs) ) t2 = perf_test_cuda_kernel( - lambda: te_unpermute(te_unpermute_fwd_input, te_dtype, row_id_map, te_probs) + lambda: te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs) ) print(f"unpermute\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 0c098830a9..9987db58e0 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -8,7 +8,8 @@ import torch import transformer_engine_torch as tex -from transformer_engine.pytorch.float8_tensor import Float8Tensor +from .constants import TE_DType +from .float8_tensor import Float8Tensor __all__ = [ @@ -27,14 +28,13 @@ class _moe_permute(torch.autograd.Function): def forward( ctx, inp: torch.Tensor, - dtype: tex.DType, indices: torch.Tensor, num_out_tokens: int, max_token_num: int, ) -> Tuple[torch.Tensor, torch.Tensor]: # Empty input check if not inp.numel(): - return inp, None + return inp, torch.tensor([], device=inp.device) # Device check assert inp.is_cuda, "TransformerEngine needs CUDA." @@ -43,16 +43,13 @@ def forward( assert inp.size(0) == indices.size(0), "Permute not possible" # Data type check - fp8 = False - if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: - fp8 = True + fp8 = isinstance(inp, Float8Tensor) if fp8: - assert isinstance( - inp, Float8Tensor - ), "Input must be in Float8Tensor type for FP8 moe_permute." - fp8_dtype = inp._fp8_dtype + dtype = inp._fp8_dtype fp8_scale_inv = inp._scale_inv inp = inp._data + else: + dtype = TE_DType[inp.dtype] if indices.dtype != torch.int32: warnings.warn( f"The data type of the input `indices` of Permute is {indices.dtype}! " @@ -78,13 +75,12 @@ def forward( if fp8: permuted_act = Float8Tensor( - data=permuted_act, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + data=permuted_act, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv ) ctx.row_id_map = row_id_map ctx.num_tokens = indices.size(0) ctx.topK = indices.size(1) - ctx.dtype = dtype ctx.fp8 = fp8 return permuted_act, row_id_map @@ -101,30 +97,27 @@ def backward( if not permuted_act_grad.is_contiguous(): permuted_act_grad = permuted_act_grad.contiguous() - fp8 = ctx.fp8 - if fp8: + if ctx.fp8: assert isinstance( permuted_act_grad, Float8Tensor ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute." - fp8_dtype = permuted_act_grad._fp8_dtype + dtype = permuted_act_grad._fp8_dtype fp8_scale_inv = permuted_act_grad._scale_inv permuted_act_grad = permuted_act_grad._data - - row_id_map = ctx.row_id_map - num_tokens = ctx.num_tokens - topK = ctx.topK + else: + dtype = TE_DType[permuted_act_grad.dtype] act_grad = None if ctx.needs_input_grad[0]: act_grad = tex.moe_permute_bwd( - permuted_act_grad, ctx.dtype, row_id_map, torch.empty(0), num_tokens, topK + permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK ) - if fp8: + if ctx.fp8: act_grad = Float8Tensor( - data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv * topK + data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv * ctx.topK ) - return act_grad, None, None, None, None + return act_grad, None, None, None class _moe_unpermute(torch.autograd.Function): @@ -134,7 +127,6 @@ class _moe_unpermute(torch.autograd.Function): def forward( ctx, inp: torch.Tensor, - dtype: tex.DType, row_id_map: torch.Tensor, probs: torch.Tensor, ) -> torch.Tensor: @@ -166,16 +158,13 @@ def forward( assert row_id_map.is_cuda, "TransformerEngine needs CUDA." # Data type check - fp8 = False - if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: - fp8 = True + fp8 = isinstance(inp, Float8Tensor) if fp8: - assert isinstance( - inp, Float8Tensor - ), "Input must be in Float8Tensor type for FP8 moe_unpermute." - fp8_dtype = inp._fp8_dtype + dtype = inp._fp8_dtype fp8_scale_inv = inp._scale_inv inp = inp._data + else: + dtype = TE_DType[inp.dtype] if row_id_map.dtype != torch.int32: warnings.warn( f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! " @@ -187,10 +176,9 @@ def forward( if fp8: unpermuted_output = Float8Tensor( - data=unpermuted_output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + data=unpermuted_output, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv ) - ctx.dtype = dtype ctx.save_for_backward(inp, row_id_map, probs) ctx.fp8 = fp8 return unpermuted_output @@ -207,35 +195,33 @@ def backward( if not unpermuted_act_grad.is_contiguous(): unpermuted_act_grad = unpermuted_act_grad.contiguous() - fp8 = ctx.fp8 - if fp8: + if ctx.fp8: assert isinstance( unpermuted_act_grad, Float8Tensor ), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute." - fp8_dtype = unpermuted_act_grad._fp8_dtype + dtype = unpermuted_act_grad._fp8_dtype fp8_scale_inv = unpermuted_act_grad._scale_inv unpermuted_act_grad = unpermuted_act_grad._data + else: + dtype = TE_DType[unpermuted_act_grad.dtype] inp, row_id_map, probs = ctx.saved_tensors act_grad = None if ctx.needs_input_grad[0]: act_grad, prob_grad = tex.moe_unpermute_bwd( - unpermuted_act_grad, inp, ctx.dtype, row_id_map, probs + unpermuted_act_grad, inp, dtype, row_id_map, probs ) - if fp8: - act_grad = Float8Tensor( - data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv - ) - if not ctx.needs_input_grad[3]: + if ctx.fp8: + act_grad = Float8Tensor(data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv) + if not ctx.needs_input_grad[2]: prob_grad = None - return act_grad, None, None, prob_grad + return act_grad, None, prob_grad def moe_permute( inp: torch.Tensor, - dtype: tex.DType, indices: torch.Tensor, num_out_tokens: int = -1, max_token_num: int = -1, @@ -247,8 +233,6 @@ def moe_permute( ---------- inp: torch.Tensor Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. - dtype: tex.DType - Data type of the input tensor. indices: torch.Tensor The token to expert indices tensor of shape [num_tokens, topK] and dtype 'int32'. num_out_tokens: int, default = -1 @@ -259,12 +243,11 @@ def moe_permute( By default, set to '-1', meaning the calculation of the size of workspace is automatically taken over by the operator. """ - return _moe_permute.apply(inp, dtype, indices, num_out_tokens, max_token_num) + return _moe_permute.apply(inp, indices, num_out_tokens, max_token_num) def moe_unpermute( inp: torch.Tensor, - dtype: tex.DType, row_id_map: torch.Tensor, probs: torch.Tensor = None, ) -> torch.Tensor: @@ -276,8 +259,6 @@ def moe_unpermute( ---------- inp: torch.Tensor Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted. - dtype: tex.DType - Data type of the input tensor. row_id_map: torch.Tensor The tensor of a mapping table for sorted indices used to unpermute the tokens, which is the second output tensor of `Permute`. @@ -286,4 +267,4 @@ def moe_unpermute( the unpermuted tokens will be merged with their respective probabilities. By default, set to an empty tensor, which means that the tokens are directly merged by accumulation. """ - return _moe_unpermute.apply(inp, dtype, row_id_map, probs) + return _moe_unpermute.apply(inp, row_id_map, probs) From aecd5a8fae4cbb73c5fa53dc607c9c83ac3626d3 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 29 Aug 2024 22:39:59 -0700 Subject: [PATCH 18/50] [PyTorch] Fix FP8 logic related to FA2/FA3 (#1141) * fix FP8 logic when FA3 is not installed Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor tweak to make logic more explicit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fixes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * limit FA3 warning to Hopper and NVTE_FLASH_ATTN=1 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * prefer fused attn for FP8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_fused_attn.py | 16 +++---- transformer_engine/pytorch/attention.py | 46 +++++++++++++++------ 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index b3225a71ef..da26c7c42f 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -22,6 +22,7 @@ get_attention_backend, _flash_attn_2_plus, _flash_attn_2_3_plus, + _flash_attn_3_plus, check_set_window_size, AttentionParams, _attention_backends, @@ -135,7 +136,6 @@ def _get_attention_backends( os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_UNFUSED_ATTN"] = "1" - global _attention_backends _attention_backends["backend_selection_requires_update"] = True alibi_slopes_shape = None @@ -678,7 +678,6 @@ def _run_dot_product_attention( if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0" - global _attention_backends _attention_backends["backend_selection_requires_update"] = True # Create seqlens @@ -1167,7 +1166,6 @@ def _run_transformer_layer( os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" - global _attention_backends _attention_backends["backend_selection_requires_update"] = True # Create input tensor @@ -1352,8 +1350,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" config = model_configs_fp8_vs_f16[model] - global _attention_backends - if not is_training: + if _flash_attn_3_plus and not is_training: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True @@ -1379,7 +1376,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, rtol = 5e-1 rmse_tol = 0.15 logging.debug("========== {:^25s} ==========".format("forward output")) - if not is_training: + if _flash_attn_3_plus and not is_training: _error( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1527,8 +1524,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" - global _attention_backends - if not is_training: + if _flash_attn_3_plus and not is_training: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True @@ -1555,7 +1551,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): rmse_tol = 0.1 bwd_names = ["dq", "dk", "dv"] logging.debug("========== {:^25s} ==========".format("forward output")) - if not is_training: + if _flash_attn_3_plus and not is_training: _error( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1778,7 +1774,6 @@ def _run_custom_mha_fp8(dtype, config, backend): os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" - global _attention_backends _attention_backends["backend_selection_requires_update"] = True inp = 0.0001 * torch.randint( @@ -1833,7 +1828,6 @@ def _run_ref_mha_f16(dtype, config, backend): os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" - global _attention_backends _attention_backends["backend_selection_requires_update"] = True inp = torch.load("qkv.pt").to(device="cuda") diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 5c7f361c0a..49c142f763 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -74,6 +74,9 @@ from transformer_engine.pytorch.graph import is_graph_capturing +_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) +_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1")) +_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) _flash_attn_version = PkgVersion(get_pkg_version("flash-attn")) _flash_attn_version_required = PkgVersion("2.0.6") _flash_attn_max_version = PkgVersion("2.6.3") @@ -89,13 +92,14 @@ _flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper")) _flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.6.1") except PackageNotFoundError: - warnings.warn( - "To use flash-attn v3, please use the following commands to install: \n" - """(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n""" - """(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n""" - """(3) mkdir -p $python_path/flashattn_hopper \n""" - """(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py""" - ) + if get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN: + warnings.warn( + "To use flash-attn v3, please use the following commands to install: \n" + """(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n""" + """(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n""" + """(3) mkdir -p $python_path/flashattn_hopper \n""" + """(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py""" + ) else: from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3 from flashattn_hopper.flash_attn_interface import ( @@ -137,10 +141,6 @@ _stream_handler = logging.StreamHandler() _stream_handler.setFormatter(_formatter) -_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) -_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1")) -_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) - _attention_backends = { "attention_params": None, "use_flash_attention": None, @@ -382,8 +382,13 @@ def get_attention_backend( # Filter: Execution type if fp8 and fp8_meta["recipe"].fp8_dpa: - if use_flash_attention and is_training: - logger.debug("Disabling FlashAttention as it does not support FP8 training") + if use_flash_attention and not _use_flash_attn_3: + logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8") + use_flash_attention = False + if use_flash_attention and _use_flash_attn_3 and is_training: + logger.debug( + "Disabling FlashAttention as FlashAttention 3 does not support FP8 training" + ) use_flash_attention = False if use_unfused_attention: logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8") @@ -826,6 +831,21 @@ def get_attention_backend( ) use_flash_attention = False + # Select FusedAttention for FP8 + # FA3 uses default scaling factors (i.e. 1) in FP8 execution, while FusedAttention takes + # scaling factors from `fp8_meta` and offers more accurate quantization/de-quantization + if ( + use_flash_attention + and use_fused_attention + and fused_attention_backend == FusedAttnBackend["FP8"] + and _use_flash_attn_3 + ): + logger.debug( + "Disabling FlashAttention 3 to give FusedAttention preference as FusedAttention " + "supports more accurate scaling factors in FP8 execution" + ) + use_flash_attention = False + # Selected backend if use_flash_attention: use_fused_attention = False From 9437ceb2b7947857c979d5a7a2ed60cd4e667a88 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> Date: Thu, 29 Aug 2024 22:44:16 -0700 Subject: [PATCH 19/50] Fix QKV dtype in the bwd of FP8+CP (#1134) * fix qkv_dtype of FP8+CP Signed-off-by: Xiaowei Ren * config cp correction dtype of FP8+CP Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * code style change Signed-off-by: Xiaowei Ren * always do FP8 CP correction in FP32 Signed-off-by: Xiaowei Ren --------- Signed-off-by: Xiaowei Ren Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 49c142f763..27d7c0fdc4 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2261,8 +2261,9 @@ def backward(ctx, dout): if ctx.fp8: if ctx.use_fused_attention: + fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - fused_attn_qkv_dtype = fp8_dtype_backward + fused_attn_qkv_dtype = fp8_dtype_forward fused_attn_dqkv_dtype = fp8_dtype_backward fused_attn_backend = FusedAttnBackend["FP8"] dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device) @@ -2304,7 +2305,7 @@ def backward(ctx, dout): if ctx.use_fused_attention: fp8_meta_kwargs = {} fused_attn_qkv_dtype = TE_DType[q.dtype] - fused_attn_dqkv_dtype = TE_DType[q.dtype] + fused_attn_dqkv_dtype = TE_DType[dout.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] out = out.view(*q.shape) From 93f00a79933ed2260ee2442b602dc72d42019eed Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 3 Sep 2024 09:24:52 -0700 Subject: [PATCH 20/50] Improvements for building wheels (#1148) * Improvements for wheels Signed-off-by: Kirthi Shankar Sivamani * fix Signed-off-by: Kirthi Shankar Sivamani * Fixes for wheel build Signed-off-by: Kirthi Shankar Sivamani * Move package finder to common Signed-off-by: Kirthi Shankar Sivamani * format Signed-off-by: Kirthi Shankar Sivamani * Fixes Signed-off-by: Kirthi Shankar Sivamani * Lint Signed-off-by: Kirthi Shankar Sivamani * fix Signed-off-by: Kirthi Shankar Sivamani * FIx Signed-off-by: Kirthi Shankar Sivamani * Fix CI and distributed test Signed-off-by: Kirthi Shankar Sivamani * fix Signed-off-by: Kirthi Shankar Sivamani * fix paddle ci Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- build_tools/utils.py | 3 +- build_tools/wheel_utils/Dockerfile.aarch | 2 +- build_tools/wheel_utils/Dockerfile.x86 | 2 +- build_tools/wheel_utils/build_wheels.sh | 56 +++++++---- qa/L0_jax_wheel/test.sh | 26 +++-- qa/L0_paddle_wheel/test.sh | 27 ++++-- qa/L0_pytorch_wheel/test.sh | 26 +++-- qa/L1_pytorch_distributed_unittest/test.sh | 4 + setup.py | 108 ++++++++++++--------- transformer_engine/common/__init__.py | 11 +++ transformer_engine/jax/__init__.py | 35 ++++++- transformer_engine/paddle/__init__.py | 32 ++++++ transformer_engine/pytorch/__init__.py | 37 ++++++- 13 files changed, 279 insertions(+), 90 deletions(-) diff --git a/build_tools/utils.py b/build_tools/utils.py index 964a445bc4..d846b87f22 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -301,7 +301,7 @@ def install_and_import(package): globals()[main_package] = importlib.import_module(main_package) -def uninstall_te_fw_packages(): +def uninstall_te_wheel_packages(): subprocess.check_call( [ sys.executable, @@ -309,6 +309,7 @@ def uninstall_te_fw_packages(): "pip", "uninstall", "-y", + "transformer_engine_cu12", "transformer_engine_torch", "transformer_engine_paddle", "transformer_engine_jax", diff --git a/build_tools/wheel_utils/Dockerfile.aarch b/build_tools/wheel_utils/Dockerfile.aarch index a0bcd80347..7d839958cb 100644 --- a/build_tools/wheel_utils/Dockerfile.aarch +++ b/build_tools/wheel_utils/Dockerfile.aarch @@ -33,4 +33,4 @@ ENV CUDA_PATH=/usr/local/cuda ENV CUDADIR=/usr/local/cuda ENV NVTE_RELEASE_BUILD=1 -CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "false", "false", "true"] +CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "true", "false", "false", "false"] diff --git a/build_tools/wheel_utils/Dockerfile.x86 b/build_tools/wheel_utils/Dockerfile.x86 index 602d99ed4d..7dedf2a761 100644 --- a/build_tools/wheel_utils/Dockerfile.x86 +++ b/build_tools/wheel_utils/Dockerfile.x86 @@ -33,4 +33,4 @@ ENV CUDA_PATH=/usr/local/cuda ENV CUDADIR=/usr/local/cuda ENV NVTE_RELEASE_BUILD=1 -CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true"] +CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true", "true"] diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index 1896fc4e42..7682a2b6aa 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -5,10 +5,11 @@ set -e PLATFORM=${1:-manylinux_2_28_x86_64} -BUILD_COMMON=${2:-true} -BUILD_JAX=${3:-true} +BUILD_METAPACKAGE=${2:-true} +BUILD_COMMON=${3:-true} BUILD_PYTORCH=${4:-true} -BUILD_PADDLE=${5:-true} +BUILD_JAX=${5:-true} +BUILD_PADDLE=${6:-true} export NVTE_RELEASE_BUILD=1 export TARGET_BRANCH=${TARGET_BRANCH:-} @@ -20,12 +21,33 @@ cd /TransformerEngine git checkout $TARGET_BRANCH git submodule update --init --recursive +if $BUILD_METAPACKAGE ; then + cd /TransformerEngine + NVTE_BUILD_METAPACKAGE=1 /opt/python/cp310-cp310/bin/python setup.py bdist_wheel 2>&1 | tee /wheelhouse/logs/metapackage.txt + mv dist/* /wheelhouse/ +fi + if $BUILD_COMMON ; then + VERSION=`cat build_tools/VERSION.txt` + WHL_BASE="transformer_engine-${VERSION}" + + # Create the wheel. /opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt + + # Repack the wheel for cuda specific package, i.e. cu12. + /opt/python/cp38-cp38/bin/wheel unpack dist/* + # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). + sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" + /opt/python/cp38-cp38/bin/wheel pack ${WHL_BASE} + + # Rename the wheel to make it python version agnostic. whl_name=$(basename dist/*) IFS='-' read -ra whl_parts <<< "$whl_name" - whl_name_target="${whl_parts[0]}-${whl_parts[1]}-py3-none-${whl_parts[4]}" - mv dist/"$whl_name" /wheelhouse/"$whl_name_target" + whl_name_target="${whl_parts[0]}_cu12-${whl_parts[1]}-py3-none-${whl_parts[4]}" + rm -rf $WHL_BASE dist + mv *.whl /wheelhouse/"$whl_name_target" fi if $BUILD_PYTORCH ; then @@ -37,8 +59,8 @@ fi if $BUILD_JAX ; then cd /TransformerEngine/transformer_engine/jax - /opt/python/cp38-cp38/bin/pip install jax jaxlib - /opt/python/cp38-cp38/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt + /opt/python/cp310-cp310/bin/pip install "jax[cuda12_local]" jaxlib + /opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt cp dist/* /wheelhouse/ fi @@ -48,30 +70,30 @@ if $BUILD_PADDLE ; then dnf -y install libcudnn8-devel.x86_64 libcudnn8.x86_64 cd /TransformerEngine/transformer_engine/paddle - /opt/python/cp38-cp38/bin/pip install /wheelhouse/*.whl + /opt/python/cp38-cp38/bin/pip install /wheelhouse/*.whl --no-deps /opt/python/cp38-cp38/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp38.txt - /opt/python/cp38-cp38/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + /opt/python/cp38-cp38/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - /opt/python/cp39-cp39/bin/pip install /wheelhouse/*.whl + /opt/python/cp39-cp39/bin/pip install /wheelhouse/*.whl --no-deps /opt/python/cp39-cp39/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp39-cp39/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp39.txt - /opt/python/cp39-cp39/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + /opt/python/cp39-cp39/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - /opt/python/cp310-cp310/bin/pip install /wheelhouse/*.whl + /opt/python/cp310-cp310/bin/pip install /wheelhouse/*.whl --no-deps /opt/python/cp310-cp310/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp310.txt - /opt/python/cp310-cp310/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + /opt/python/cp310-cp310/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - /opt/python/cp311-cp311/bin/pip install /wheelhouse/*.whl + /opt/python/cp311-cp311/bin/pip install /wheelhouse/*.whl --no-deps /opt/python/cp311-cp311/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp311-cp311/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp311.txt - /opt/python/cp311-cp311/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + /opt/python/cp311-cp311/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - /opt/python/cp312-cp312/bin/pip install /wheelhouse/*.whl + /opt/python/cp312-cp312/bin/pip install /wheelhouse/*.whl --no-deps /opt/python/cp312-cp312/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp312-cp312/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp312.txt - /opt/python/cp312-cp312/bin/pip uninstall -y transformer-engine paddlepaddle-gpu + /opt/python/cp312-cp312/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu mv dist/* /wheelhouse/ fi diff --git a/qa/L0_jax_wheel/test.sh b/qa/L0_jax_wheel/test.sh index 109633495b..2c3b832933 100644 --- a/qa/L0_jax_wheel/test.sh +++ b/qa/L0_jax_wheel/test.sh @@ -6,16 +6,30 @@ set -e : "${TE_PATH:=/opt/transformerengine}" +pip install wheel + cd $TE_PATH -pip uninstall -y transformer-engine -export NVTE_RELEASE_BUILD=1 -python setup.py bdist_wheel +pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-jax + +VERSION=`cat $TE_PATH/build_tools/VERSION.txt` +WHL_BASE="transformer_engine-${VERSION}" + +# Core wheel. +NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel +wheel unpack dist/* +sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" +wheel pack ${WHL_BASE} +rm dist/*.whl +mv *.whl dist/ +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel + cd transformer_engine/jax -python setup.py sdist +NVTE_RELEASE_BUILD=1 python setup.py sdist -export NVTE_RELEASE_BUILD=0 pip install dist/* cd $TE_PATH -pip install dist/* +pip install dist/*.whl --no-deps python $TE_PATH/tests/jax/test_sanity_import.py diff --git a/qa/L0_paddle_wheel/test.sh b/qa/L0_paddle_wheel/test.sh index e2d6d38dd4..30fbb1df1f 100644 --- a/qa/L0_paddle_wheel/test.sh +++ b/qa/L0_paddle_wheel/test.sh @@ -6,15 +6,28 @@ set -e : "${TE_PATH:=/opt/transformerengine}" +pip install wheel==0.44.0 pydantic + cd $TE_PATH -pip uninstall -y transformer-engine -export NVTE_RELEASE_BUILD=1 -python setup.py bdist_wheel -pip install dist/* -cd transformer_engine/paddle -python setup.py bdist_wheel +pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-paddle -export NVTE_RELEASE_BUILD=0 +VERSION=`cat $TE_PATH/build_tools/VERSION.txt` +WHL_BASE="transformer_engine-${VERSION}" + +# Core wheel. +NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel +wheel unpack dist/* +sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" +wheel pack ${WHL_BASE} +rm dist/*.whl +mv *.whl dist/ +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel +pip install dist/*.whl --no-deps + +cd transformer_engine/paddle +NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel pip install dist/* python $TE_PATH/tests/paddle/test_sanity_import.py diff --git a/qa/L0_pytorch_wheel/test.sh b/qa/L0_pytorch_wheel/test.sh index e108e93cdb..fd8457c44b 100644 --- a/qa/L0_pytorch_wheel/test.sh +++ b/qa/L0_pytorch_wheel/test.sh @@ -6,16 +6,30 @@ set -e : "${TE_PATH:=/opt/transformerengine}" +pip install wheel + cd $TE_PATH -pip uninstall -y transformer-engine -export NVTE_RELEASE_BUILD=1 -python setup.py bdist_wheel +pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-torch + +VERSION=`cat $TE_PATH/build_tools/VERSION.txt` +WHL_BASE="transformer_engine-${VERSION}" + +# Core wheel. +NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel +wheel unpack dist/* +sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" +mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" +wheel pack ${WHL_BASE} +rm dist/*.whl +mv *.whl dist/ +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel + cd transformer_engine/pytorch -python setup.py sdist +NVTE_RELEASE_BUILD=1 python setup.py sdist -export NVTE_RELEASE_BUILD=0 pip install dist/* cd $TE_PATH -pip install dist/* +pip install dist/*.whl --no-deps python $TE_PATH/tests/pytorch/test_sanity_import.py diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index fef48fd4b0..50394c33a9 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -4,6 +4,10 @@ set -e +# pkg_resources is deprecated in setuptools 70+ and the packaging submodule +# has been removed from it. This is a temporary fix until upstream MLM fix. +pip install setuptools==69.5.1 + : ${TE_PATH:=/opt/transformerengine} pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py diff --git a/setup.py b/setup.py index 4e5359e9c0..0b0639aea6 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ get_frameworks, install_and_import, remove_dups, - uninstall_te_fw_packages, + uninstall_te_wheel_packages, ) frameworks = get_frameworks() @@ -106,46 +106,69 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: if __name__ == "__main__": - # Dependencies - setup_requires, install_requires, test_requires = setup_requirements() - __version__ = te_version() - ext_modules = [setup_common_extension()] - if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): - # Remove residual FW packages since compiling from source - # results in a single binary with FW extensions included. - uninstall_te_fw_packages() - if "pytorch" in frameworks: - from build_tools.pytorch import setup_pytorch_extension - - ext_modules.append( - setup_pytorch_extension( - "transformer_engine/pytorch/csrc", - current_file_path / "transformer_engine" / "pytorch" / "csrc", - current_file_path / "transformer_engine", + with open("README.rst", encoding="utf-8") as f: + long_description = f.read() + + # Settings for building top level empty package for dependency management. + if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))): + assert bool( + int(os.getenv("NVTE_RELEASE_BUILD", "0")) + ), "NVTE_RELEASE_BUILD env must be set for metapackage build." + ext_modules = [] + cmdclass = {} + package_data = {} + include_package_data = False + setup_requires = [] + install_requires = ([f"transformer_engine_cu12=={__version__}"],) + extras_require = { + "pytorch": [f"transformer_engine_torch=={__version__}"], + "jax": [f"transformer_engine_jax=={__version__}"], + "paddle": [f"transformer_engine_paddle=={__version__}"], + } + else: + setup_requires, install_requires, test_requires = setup_requirements() + ext_modules = [setup_common_extension()] + cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist} + package_data = {"": ["VERSION.txt"]} + include_package_data = True + extras_require = {"test": test_requires} + + if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): + # Remove residual FW packages since compiling from source + # results in a single binary with FW extensions included. + uninstall_te_wheel_packages() + if "pytorch" in frameworks: + from build_tools.pytorch import setup_pytorch_extension + + ext_modules.append( + setup_pytorch_extension( + "transformer_engine/pytorch/csrc", + current_file_path / "transformer_engine" / "pytorch" / "csrc", + current_file_path / "transformer_engine", + ) ) - ) - if "jax" in frameworks: - from build_tools.jax import setup_jax_extension - - ext_modules.append( - setup_jax_extension( - "transformer_engine/jax/csrc", - current_file_path / "transformer_engine" / "jax" / "csrc", - current_file_path / "transformer_engine", + if "jax" in frameworks: + from build_tools.jax import setup_jax_extension + + ext_modules.append( + setup_jax_extension( + "transformer_engine/jax/csrc", + current_file_path / "transformer_engine" / "jax" / "csrc", + current_file_path / "transformer_engine", + ) ) - ) - if "paddle" in frameworks: - from build_tools.paddle import setup_paddle_extension - - ext_modules.append( - setup_paddle_extension( - "transformer_engine/paddle/csrc", - current_file_path / "transformer_engine" / "paddle" / "csrc", - current_file_path / "transformer_engine", + if "paddle" in frameworks: + from build_tools.paddle import setup_paddle_extension + + ext_modules.append( + setup_paddle_extension( + "transformer_engine/paddle/csrc", + current_file_path / "transformer_engine" / "paddle" / "csrc", + current_file_path / "transformer_engine", + ) ) - ) # Configure package setuptools.setup( @@ -158,13 +181,10 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: "transformer_engine/build_tools", ], ), - extras_require={ - "test": test_requires, - "pytorch": [f"transformer_engine_torch=={__version__}"], - "jax": [f"transformer_engine_jax=={__version__}"], - "paddle": [f"transformer_engine_paddle=={__version__}"], - }, + extras_require=extras_require, description="Transformer acceleration library", + long_description=long_description, + long_description_content_type="text/x-rst", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, python_requires=">=3.8, <3.13", @@ -178,6 +198,6 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: setup_requires=setup_requires, install_requires=install_requires, license_files=("LICENSE",), - include_package_data=True, - package_data={"": ["VERSION.txt"]}, + include_package_data=include_package_data, + package_data=package_data, ) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index f4eb2c419f..46cfa9176a 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -4,6 +4,7 @@ """FW agnostic user-end APIs""" +import sys import glob import sysconfig import subprocess @@ -15,6 +16,16 @@ import transformer_engine +def is_package_installed(package): + """Checks if a pip package is installed.""" + return ( + subprocess.run( + [sys.executable, "-m", "pip", "show", package], capture_output=True, check=False + ).returncode + == 0 + ) + + def get_te_path(): """Find Transformer Engine install path using pip""" return Path(transformer_engine.__path__[0]).parent diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index 3200c8a019..05adbd624c 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -5,21 +5,50 @@ # pylint: disable=wrong-import-position,wrong-import-order +import logging import ctypes +from importlib.metadata import version -from transformer_engine.common import get_te_path +from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import _get_sys_extension def _load_library(): """Load shared library with Transformer Engine C extensions""" + module_name = "transformer_engine_jax" + + if is_package_installed(module_name): + assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`." + assert is_package_installed( + "transformer_engine_cu12" + ), "Could not find `transformer-engine-cu12`." + assert ( + version(module_name) + == version("transformer-engine") + == version("transformer-engine-cu12") + ), ( + "TransformerEngine package version mismatch. Found" + f" {module_name} v{version(module_name)}, transformer-engine" + f" v{version('transformer-engine')}, and transformer-engine-cu12" + f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install" + " transformer-engine[jax]==VERSION'" + ) + + if is_package_installed("transformer-engine-cu12"): + if not is_package_installed(module_name): + logging.info( + "Could not find package %s. Install transformer-engine using 'pip" + " install transformer-engine[jax]==VERSION'", + module_name, + ) + extension = _get_sys_extension() try: so_dir = get_te_path() / "transformer_engine" - so_path = next(so_dir.glob(f"transformer_engine_jax.*.{extension}")) + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) except StopIteration: so_dir = get_te_path() - so_path = next(so_dir.glob(f"transformer_engine_jax.*.{extension}")) + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) diff --git a/transformer_engine/paddle/__init__.py b/transformer_engine/paddle/__init__.py index 62fa1fe626..50cf2186d6 100644 --- a/transformer_engine/paddle/__init__.py +++ b/transformer_engine/paddle/__init__.py @@ -6,9 +6,41 @@ # pylint: disable=wrong-import-position,wrong-import-order +import logging +from importlib.metadata import version + +from transformer_engine.common import is_package_installed + def _load_library(): """Load shared library with Transformer Engine C extensions""" + module_name = "transformer_engine_paddle" + + if is_package_installed(module_name): + assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`." + assert is_package_installed( + "transformer_engine_cu12" + ), "Could not find `transformer-engine-cu12`." + assert ( + version(module_name) + == version("transformer-engine") + == version("transformer-engine-cu12") + ), ( + "TransformerEngine package version mismatch. Found" + f" {module_name} v{version(module_name)}, transformer-engine" + f" v{version('transformer-engine')}, and transformer-engine-cu12" + f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install" + " transformer-engine[paddle]==VERSION'" + ) + + if is_package_installed("transformer-engine-cu12"): + if not is_package_installed(module_name): + logging.info( + "Could not find package %s. Install transformer-engine using 'pip" + " install transformer-engine[paddle]==VERSION'", + module_name, + ) + from transformer_engine import transformer_engine_paddle # pylint: disable=unused-import diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 1c755491b0..89b20002a7 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -6,25 +6,54 @@ # pylint: disable=wrong-import-position,wrong-import-order +import logging import importlib +import importlib.util import sys import torch +from importlib.metadata import version -from transformer_engine.common import get_te_path +from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import _get_sys_extension def _load_library(): """Load shared library with Transformer Engine C extensions""" + module_name = "transformer_engine_torch" + + if is_package_installed(module_name): + assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`." + assert is_package_installed( + "transformer_engine_cu12" + ), "Could not find `transformer-engine-cu12`." + assert ( + version(module_name) + == version("transformer-engine") + == version("transformer-engine-cu12") + ), ( + "TransformerEngine package version mismatch. Found" + f" {module_name} v{version(module_name)}, transformer-engine" + f" v{version('transformer-engine')}, and transformer-engine-cu12" + f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install" + " transformer-engine[pytorch]==VERSION'" + ) + + if is_package_installed("transformer-engine-cu12"): + if not is_package_installed(module_name): + logging.info( + "Could not find package %s. Install transformer-engine using 'pip" + " install transformer-engine[pytorch]==VERSION'", + module_name, + ) + extension = _get_sys_extension() try: so_dir = get_te_path() / "transformer_engine" - so_path = next(so_dir.glob(f"transformer_engine_torch.*.{extension}")) + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) except StopIteration: so_dir = get_te_path() - so_path = next(so_dir.glob(f"transformer_engine_torch.*.{extension}")) + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) - module_name = "transformer_engine_torch" spec = importlib.util.spec_from_file_location(module_name, so_path) solib = importlib.util.module_from_spec(spec) sys.modules[module_name] = solib From ddc5774d522133e27392b515fc77650f635a7b11 Mon Sep 17 00:00:00 2001 From: Ruibin Cheung Date: Wed, 4 Sep 2024 00:25:28 +0800 Subject: [PATCH 21/50] [PyTorch] Add contiguous check for `te_grouped_gemm` (#1146) [PyTorch] Add contiguous check for grouped gemm Signed-off-by: beinggod Co-authored-by: beinggod Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/csrc/extensions/gemm.cu | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu index 01fb94cab4..7405914a0e 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cu +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu @@ -115,6 +115,11 @@ void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int if (pre_gelu_out[i].data_ptr() != nullptr) pre_gelu_out[i].zero_(); continue; } + + NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous."); + NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous."); + NVTE_CHECK(D[i].is_contiguous(), "D[", i, "] must be contiguous."); + te_A.emplace_back(make_tensor( A[i].data_ptr(), {static_cast(A[i].size(0)), static_cast(A[i].size(1))}, A_type, nullptr, nullptr, getDataPtr(A_scale_inverse, A_offset + i))); From af9f2fae2206816575ee6f0cdd4ed310a8e086f0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 3 Sep 2024 17:05:19 -0700 Subject: [PATCH 22/50] Bump actions/download-artifact from 3 to 4.1.7 in /.github/workflows (#1154) Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 3 to 4.1.7. - [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](https://github.com/actions/download-artifact/compare/v3...v4.1.7) --- updated-dependencies: - dependency-name: actions/download-artifact dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/deploy_nightly_docs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/deploy_nightly_docs.yml b/.github/workflows/deploy_nightly_docs.yml index 3820dbde50..cd68019c8f 100644 --- a/.github/workflows/deploy_nightly_docs.yml +++ b/.github/workflows/deploy_nightly_docs.yml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Download artifact - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4.1.7 with: name: "te_docs" path: "html" From 247850e8bd9f13ccce729fd8204066193d27a167 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 3 Sep 2024 22:04:51 -0700 Subject: [PATCH 23/50] Add user to TE CI (#1155) Signed-off-by: Tim Moon --- .github/workflows/trigger-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index cd47fa9a54..c66571b536 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -33,6 +33,7 @@ jobs: || github.actor == 'Oleg-Goncharov' || github.actor == 'phu0ngng' || github.actor == 'xrennvidia' + || github.actor == 'yaox12' ) steps: - name: Check if comment is issued by authorized person From 5fafeb0efef60d6f10574bb4366cdc5a8db7192d Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Thu, 5 Sep 2024 13:57:30 +0800 Subject: [PATCH 24/50] [PyTorch] FP8 MHA with RoPE and Miscellaneous Improvements (#1100) * fp8 mha with rope Signed-off-by: Xin Yao * avoid index select in cast ops Signed-off-by: Xin Yao * avoid index select in fused_attn_fwd Signed-off-by: Xin Yao * rename is_first_module_in_mha to fp8_output Signed-off-by: Xin Yao * resolve comments Signed-off-by: Xin Yao * resolve comments Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * move transpose to backward for fp8 input Signed-off-by: Xin Yao * fix ut Signed-off-by: Xin Yao * resolve comments Signed-off-by: Xin Yao * update argument list for CP Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix for FA3 Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove unnecessary copy of scale_inv Signed-off-by: Xin Yao * skip fp8 dpa/mha tests when fa3 is not available Signed-off-by: Xin Yao * fix a merge bug Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/fused_attn/test_fused_attn.py | 36 +- transformer_engine/pytorch/attention.py | 444 ++++++++++-------- .../pytorch/cpp_extensions/fused_attn.py | 84 +++- transformer_engine/pytorch/csrc/extensions.h | 47 +- .../pytorch/csrc/extensions/attention.cu | 81 ++-- .../pytorch/csrc/extensions/cast.cu | 36 +- .../pytorch/csrc/extensions/pybind.cpp | 12 +- transformer_engine/pytorch/csrc/ts_fp8_op.cpp | 8 +- .../pytorch/module/layernorm_linear.py | 6 +- transformer_engine/pytorch/module/linear.py | 24 +- 10 files changed, 491 insertions(+), 287 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index da26c7c42f..a1ebead04a 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -1344,19 +1344,22 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol): @pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16) @pytest.mark.parametrize("input_layernorm", [True, False]) @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) +@pytest.mark.parametrize("RoPE", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, is_training): +def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training): os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" config = model_configs_fp8_vs_f16[model] if _flash_attn_3_plus and not is_training: + if RoPE: + pytest.skip("Flash Attention doesn't support FP8 MHA with RoPE.") os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16( - dtype, config, True, qkv_format, input_layernorm, is_training + dtype, config, True, qkv_format, input_layernorm, RoPE, is_training ) os.environ["NVTE_FLASH_ATTN"] = "0" @@ -1364,12 +1367,12 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16( - dtype, config, True, qkv_format, input_layernorm, is_training + dtype, config, True, qkv_format, input_layernorm, RoPE, is_training ) logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False") fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16( - dtype, config, False, qkv_format, input_layernorm, is_training + dtype, config, False, qkv_format, input_layernorm, RoPE, is_training ) atol = 5e-1 @@ -1410,7 +1413,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, ) -def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, is_training): +def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training): reset_rng_states() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) @@ -1429,6 +1432,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: ) with fp8_model_init(enabled=fp8_mha): + rotary_pos_emb = None + if RoPE: + PE = RotaryPositionEmbedding(dim=config.head_dim_qk) + rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda") mha = MultiheadAttention( hidden_size=config.hidden_size, num_attention_heads=config.num_heads, @@ -1489,6 +1496,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: checkpoint_core_attention=False, core_attention_bias_type=config.attn_bias_type, is_first_microbatch=None, + rotary_pos_emb=rotary_pos_emb, ) if is_training: out.backward(out_grad) @@ -1977,12 +1985,18 @@ def forward( None, None, None, - fp8_meta["scaling_fwd"].scale_inv[META_QKV], - fp8_meta["scaling_fwd"].scale_inv[META_S], - fp8_meta["scaling_fwd"].scale[META_S], - fp8_meta["scaling_fwd"].scale[META_O], - fp8_meta["scaling_fwd"].amax_history[0][META_S], - fp8_meta["scaling_fwd"].amax_history[0][META_O], + fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv + META_QKV, # d_scale_qkv_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_s + META_S, # d_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_s + META_S, # q_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_o + META_O, # q_scale_o_offset + fp8_meta["scaling_fwd"].amax_history, # amax_s + META_S, # amax_s_offset + fp8_meta["scaling_fwd"].amax_history, # amax_o + META_O, # amax_o_offset attn_scale=None, dropout=p_dropout, fast_zero_fill=fast_zero_fill, diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 27d7c0fdc4..59bc26140d 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -38,8 +38,20 @@ AttnBiasType, AttnMaskType, FusedAttnBackend, + META_QKV, + META_DQKV, + META_O, + META_DO, + META_S, + META_DP, + META_O_CP, + META_DQKV_CP, +) +from transformer_engine.pytorch.fp8 import ( + FP8GlobalStateManager, + get_fp8_te_dtype, + get_fp8_torch_dtype, ) -from transformer_engine.pytorch.fp8 import get_fp8_te_dtype, get_fp8_torch_dtype from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.module import LayerNormLinear, Linear from transformer_engine.pytorch.module.base import TransformerEngineBaseModule @@ -120,15 +132,6 @@ from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd -META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT -META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1 -META_O = tex.FP8FwdTensors.GEMM2_INPUT -META_DO = tex.FP8BwdTensors.GRAD_INPUT2 -META_S = tex.FP8FwdTensors.GEMM3_OUTPUT -META_DP = tex.FP8BwdTensors.GRAD_INPUT3 -# repurpose some unused amax history buffers for partial results of CP fwd and bwd -META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT -META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1 # NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) @@ -1546,10 +1549,14 @@ def forward( for x in [k_f16, v_f16] ] fp8_meta_kwargs = {} - fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv[META_QKV] - fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv[META_S] - fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale[META_S] - fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale[META_O_CP] + fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv + fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV + fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv + fp8_meta_kwargs["d_scale_s_offset"] = META_S + fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale + fp8_meta_kwargs["q_scale_s_offset"] = META_S + fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale + fp8_meta_kwargs["q_scale_o_offset"] = META_O_CP amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) else: assert False, "FP8 is only supported with Fused Attention!" @@ -1601,8 +1608,10 @@ def forward( fp8_dtype_forward, ) if fp8 and use_fused_attention: - fp8_meta_kwargs["amax_s"] = amax_per_step[0][i] - fp8_meta_kwargs["amax_o"] = amax_per_step[1][i] + fp8_meta_kwargs["amax_s"] = amax_per_step + fp8_meta_kwargs["amax_s_offset"] = i + fp8_meta_kwargs["amax_o"] = amax_per_step + fp8_meta_kwargs["amax_o_offset"] = cp_size + i if causal: if i == 0: if pad_between_seqs_q: @@ -4153,9 +4162,8 @@ def run_iteratively(q, k, v): stride = q.stride() check_strides_qkv = all(stride == x.stride() for x in [q, k, v]) - stride = k.stride() - check_strides_kv = torch.equal( - torch.Tensor(stride[:-1]) / k.shape[-1], torch.Tensor(v.stride()[:-1]) / v.shape[-1] + check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple( + sv / v.shape[-1] for sv in v.stride()[:-1] ) shape = q.shape @@ -4635,19 +4643,20 @@ def forward( fp8_meta, deterministic, ): + is_input_fp8 = False + is_output_fp8 = fp8_meta["recipe"].fp8_mha if fp8: - if fp8_meta["recipe"].fp8_mha: - assert isinstance(qkv, Float8Tensor), "qkv must be Float8Tensors for FP8 MHA." + is_input_fp8 = isinstance(qkv, Float8Tensor) + if is_input_fp8: fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv fused_attention_backend = FusedAttnBackend["FP8"] fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) # 1: qkv packed, 2: kv packed, 3: qkv separate qkv_group = len(qkv_layout.split("_")) - assert qkv_group == 1, ( - "qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found" - f" {qkv_layout}." - ) - if fp8_meta["recipe"].fp8_mha: + assert ( + qkv_group == 1 + ), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found {qkv_layout}." + if is_input_fp8: qkv_fp8 = qkv._data else: qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) @@ -4663,12 +4672,18 @@ def forward( fused_attention_backend, attn_bias, cu_seqlens_padded, - fp8_meta["scaling_fwd"].scale_inv[META_QKV], - fp8_meta["scaling_fwd"].scale_inv[META_S], - fp8_meta["scaling_fwd"].scale[META_S], - fp8_meta["scaling_fwd"].scale[META_O], - fp8_meta["scaling_fwd"].amax_history[0][META_S], - fp8_meta["scaling_fwd"].amax_history[0][META_O], + fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv + META_QKV, # d_scale_qkv_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_s + META_S, # d_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_s + META_S, # q_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_o + META_O, # q_scale_o_offset + fp8_meta["scaling_fwd"].amax_history, # amax_s + META_S, # amax_s_offset + fp8_meta["scaling_fwd"].amax_history, # amax_o + META_O, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -4678,7 +4693,7 @@ def forward( window_size, rng_gen, ) - if fp8_meta["recipe"].fp8_mha: + if is_output_fp8: out_ret = Float8Tensor( data=out_fp8, fp8_meta=fp8_meta, @@ -4696,22 +4711,24 @@ def forward( qkv_dtype, ).view(out_fp8.shape) out_save = out_ret - if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv = cast_from_fp8( - qkv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[qkv.dtype], - ).view(qkv.shape) - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) + if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + if is_input_fp8: + qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) + qkv = cast_from_fp8( + qkv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[qkv.dtype], + ).view(qkv.shape) + if is_output_fp8: + out_save = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), + fp8_meta["scaling_fwd"], + META_O, + fp8_dtype_forward, + qkv_dtype, + ).view(out_fp8.shape) fp8_tensors = ( qkv_fp8, out_fp8, @@ -4728,12 +4745,18 @@ def forward( fused_attention_backend, attn_bias, cu_seqlens_padded, - None, - None, - None, - None, - None, - None, + None, # d_scale_qkv + 0, # d_scale_qkv_offset + None, # d_scale_s + 0, # d_scale_s_offset + None, # q_scale_s + 0, # q_scale_s_offset + None, # q_scale_o + 0, # q_scale_o_offset + None, # amax_s + 0, # amax_s_offset + None, # amax_o + 0, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -4747,6 +4770,8 @@ def forward( out_save = out_ret ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None) ctx.save_for_backward( *qkvo_tensors, cu_seqlens, cu_seqlens_padded, *fp8_tensors, *aux_ctx_tensors @@ -4771,7 +4796,7 @@ def forward( @staticmethod def backward(ctx, d_out): - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: assert isinstance( d_out, Float8Tensor ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." @@ -4828,7 +4853,7 @@ def backward(ctx, d_out): fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: d_out_fp8 = d_out ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv else: @@ -4868,7 +4893,7 @@ def backward(ctx, d_out): ctx.window_size, ctx.deterministic, ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_input_fp8: dqkv = Float8Tensor( data=dqkv_fp8, fp8_meta=ctx.fp8_meta, @@ -5006,22 +5031,23 @@ def forward( fp8_meta, deterministic, ): + is_input_fp8 = False + is_output_fp8 = fp8_meta["recipe"].fp8_mha if fp8: - if fp8_meta["recipe"].fp8_mha: - assert isinstance(q, Float8Tensor) and isinstance( - kv, Float8Tensor - ), "q/kv must be Float8Tensors for FP8 MHA." + assert isinstance(kv, q.__class__), "q and kv must have the same type." + is_input_fp8 = isinstance(q, Float8Tensor) + if is_input_fp8: fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv fused_attention_backend = FusedAttnBackend["FP8"] fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if fp8_meta["recipe"].fp8_mha: + if is_input_fp8: q_fp8, kv_fp8 = q._data, kv._data else: # 1: qkv packed, 2: kv packed, 3: qkv separate qkv_group = len(qkv_layout.split("_")) assert qkv_group == 2, ( - "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, " - f" but found {qkv_layout}." + "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, " + f"but found {qkv_layout}." ) q_fp8 = cast_to_fp8(q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward).view( q.shape @@ -5043,12 +5069,18 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fp8_meta["scaling_fwd"].scale_inv[META_QKV], - fp8_meta["scaling_fwd"].scale_inv[META_S], - fp8_meta["scaling_fwd"].scale[META_S], - fp8_meta["scaling_fwd"].scale[META_O], - fp8_meta["scaling_fwd"].amax_history[0][META_S], - fp8_meta["scaling_fwd"].amax_history[0][META_O], + fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv + META_QKV, # d_scale_qkv_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_s + META_S, # d_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_s + META_S, # q_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_o + META_O, # q_scale_o_offset + fp8_meta["scaling_fwd"].amax_history, # amax_s + META_S, # amax_s_offset + fp8_meta["scaling_fwd"].amax_history, # amax_o + META_O, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -5058,7 +5090,7 @@ def forward( window_size, rng_gen, ) - if fp8_meta["recipe"].fp8_mha: + if is_output_fp8: out_ret = Float8Tensor( data=out_fp8, fp8_meta=fp8_meta, @@ -5076,25 +5108,31 @@ def forward( qkv_dtype, ).view(out_fp8.shape) out_save = out_ret - if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q = cast_from_fp8( - q._data, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, TE_DType[q.dtype] - ).view(q.shape) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv = cast_from_fp8( - kv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[kv.dtype], - ).view(kv.shape) - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) + if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + if is_input_fp8: + q = cast_from_fp8( + q._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[q.dtype], + ).view(q.shape) + kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) + kv = cast_from_fp8( + kv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[kv.dtype], + ).view(kv.shape) + if is_output_fp8: + out_save = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), + fp8_meta["scaling_fwd"], + META_O, + fp8_dtype_forward, + qkv_dtype, + ).view(out_fp8.shape) fp8_tensors = ( q_fp8, kv_fp8, @@ -5116,12 +5154,18 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - None, - None, - None, - None, - None, - None, + None, # d_scale_qkv + 0, # d_scale_qkv_offset + None, # d_scale_s + 0, # d_scale_s_offset + None, # q_scale_s + 0, # q_scale_s_offset + None, # q_scale_o + 0, # q_scale_o_offset + None, # amax_s + 0, # amax_s_offset + None, # amax_o + 0, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -5135,6 +5179,8 @@ def forward( fp8_tensors = (None, None, None, None, None) ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None) ctx.save_for_backward( *qkvo_tensors, @@ -5166,7 +5212,7 @@ def forward( @staticmethod def backward(ctx, d_out): - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: assert isinstance( d_out, Float8Tensor ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." @@ -5227,7 +5273,7 @@ def backward(ctx, d_out): fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: d_out_fp8 = d_out ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv else: @@ -5271,7 +5317,7 @@ def backward(ctx, d_out): ctx.window_size, ctx.deterministic, ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_input_fp8: dq = Float8Tensor( data=dq_fp8, fp8_meta=ctx.fp8_meta, @@ -5437,15 +5483,16 @@ def forward( fp8_meta, deterministic, ): + is_input_fp8 = False + is_output_fp8 = fp8_meta["recipe"].fp8_mha if fp8: fused_attention_backend = FusedAttnBackend["FP8"] fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if fp8_meta["recipe"].fp8_mha: - assert ( - isinstance(q, Float8Tensor) - and isinstance(k, Float8Tensor) - and isinstance(v, Float8Tensor) - ), "q/k/v must be Float8Tensors for FP8 MHA." + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, and v must have the same type." + is_input_fp8 = isinstance(q, Float8Tensor) + if is_input_fp8: fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data else: @@ -5496,12 +5543,18 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fp8_meta["scaling_fwd"].scale_inv[META_QKV], - fp8_meta["scaling_fwd"].scale_inv[META_S], - fp8_meta["scaling_fwd"].scale[META_S], - fp8_meta["scaling_fwd"].scale[META_O], - fp8_meta["scaling_fwd"].amax_history[0][META_S], - fp8_meta["scaling_fwd"].amax_history[0][META_O], + fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv + META_QKV, # d_scale_qkv_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_s + META_S, # d_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_s + META_S, # q_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_o + META_O, # q_scale_o_offset + fp8_meta["scaling_fwd"].amax_history, # amax_s + META_S, # amax_s_offset + fp8_meta["scaling_fwd"].amax_history, # amax_o + META_O, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -5511,7 +5564,7 @@ def forward( window_size, rng_gen, ) - if fp8_meta["recipe"].fp8_mha: + if is_output_fp8: out_ret = Float8Tensor( data=out_fp8, fp8_meta=fp8_meta, @@ -5530,71 +5583,73 @@ def forward( ).view(out_fp8.shape) out_save = out_ret - if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split("_")) - if qkv_group == 1: - dim = qkv_layout.find("3") - qkv = _combine_tensors([q, k, v], dim) - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_no_fp8 = cast_from_fp8( - qkv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[qkv.dtype], - ).view(qkv.shape) - q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1]) - q, k, v = [x.squeeze(dim) for x in [q, k, v]] - if qkv_group == 2: - q = cast_from_fp8( - q._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[q.dtype], - ).view(q.shape) - dim = qkv_layout.split("_")[1].find("2") - kv = _combine_tensors([k, v], dim) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_no_fp8 = cast_from_fp8( - kv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[kv.dtype], - ).view(kv.shape) - k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1]) - k, v = [x.squeeze(dim) for x in [k, v]] - if qkv_group == 3: - q = cast_from_fp8( - q._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[q.dtype], - ).view(q.shape) - k = cast_from_fp8( - k._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[k.dtype], - ).view(k.shape) - v = cast_from_fp8( - v._data, + if is_input_fp8: + qkv_group = len(qkv_layout.split("_")) + if qkv_group == 1: + dim = qkv_layout.find("3") + qkv = _combine_tensors([q, k, v], dim) + qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) + qkv_no_fp8 = cast_from_fp8( + qkv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[qkv.dtype], + ).view(qkv.shape) + q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1]) + q, k, v = [x.squeeze(dim) for x in [q, k, v]] + if qkv_group == 2: + q = cast_from_fp8( + q._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[q.dtype], + ).view(q.shape) + dim = qkv_layout.split("_")[1].find("2") + kv = _combine_tensors([k, v], dim) + kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) + kv_no_fp8 = cast_from_fp8( + kv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[kv.dtype], + ).view(kv.shape) + k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1]) + k, v = [x.squeeze(dim) for x in [k, v]] + if qkv_group == 3: + q = cast_from_fp8( + q._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[q.dtype], + ).view(q.shape) + k = cast_from_fp8( + k._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[k.dtype], + ).view(k.shape) + v = cast_from_fp8( + v._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[v.dtype], + ).view(v.shape) + if is_output_fp8: + out_save = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), fp8_meta["scaling_fwd"], - META_QKV, + META_O, fp8_dtype_forward, - TE_DType[v.dtype], - ).view(v.shape) - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) + qkv_dtype, + ).view(out_fp8.shape) fp8_tensors = ( q_fp8, @@ -5619,12 +5674,18 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - None, - None, - None, - None, - None, - None, + None, # d_scale_qkv + 0, # d_scale_qkv_offset + None, # d_scale_s + 0, # d_scale_s_offset + None, # q_scale_s + 0, # q_scale_s_offset + None, # q_scale_o + 0, # q_scale_o_offset + None, # amax_s + 0, # amax_s_offset + None, # amax_o + 0, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -5647,6 +5708,8 @@ def forward( tensor.activation_offloading = True ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) ctx.save_for_backward( *qkvo_tensors, @@ -5678,7 +5741,7 @@ def forward( @staticmethod def backward(ctx, d_out): - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: assert isinstance( d_out, Float8Tensor ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." @@ -5743,7 +5806,7 @@ def backward(ctx, d_out): fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: d_out_fp8 = d_out ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv else: @@ -5789,7 +5852,7 @@ def backward(ctx, d_out): ctx.deterministic, ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_input_fp8: dq = Float8Tensor( data=dq_fp8, fp8_meta=ctx.fp8_meta, @@ -7719,12 +7782,18 @@ def forward( # Query, Key, and Value # ====================== + fp8_mha = ( + FP8GlobalStateManager.is_fp8_enabled() + and FP8GlobalStateManager.get_fp8_recipe().fp8_mha + ) + if self.attention_type == "self": # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn] if self.input_layernorm: layernorm_qkv_outputs = self.layernorm_qkv( hidden_states, is_first_microbatch=is_first_microbatch, + fp8_output=fp8_mha and rotary_pos_emb is None, ) if self.return_layernorm_output: mixed_x_layer, layernorm_output = layernorm_qkv_outputs @@ -7734,7 +7803,7 @@ def forward( mixed_x_layer = self.qkv( hidden_states, is_first_microbatch=is_first_microbatch, - is_first_module_in_mha=True, # specific to FP8 MHA + fp8_output=fp8_mha and rotary_pos_emb is None, ) num_queries_per_key_value = ( @@ -7795,7 +7864,7 @@ def forward( mixed_kv_layer = self.key_value( encoder_output, is_first_microbatch=is_first_microbatch, - is_first_module_in_mha=True, # specific to FP8 MHA + fp8_output=fp8_mha and rotary_pos_emb is None, ) if self.qkv_weight_interleaved: @@ -7845,6 +7914,7 @@ def forward( layernorm_query_outputs = self.layernorm_query( hidden_states, is_first_microbatch=is_first_microbatch, + fp8_output=fp8_mha and rotary_pos_emb is None, ) if self.return_layernorm_output: query_layer, layernorm_output = layernorm_query_outputs @@ -7854,7 +7924,7 @@ def forward( query_layer = self.query_layer( hidden_states, is_first_microbatch=is_first_microbatch, - is_first_module_in_mha=True, # specific to FP8 MHA + fp8_output=fp8_mha and rotary_pos_emb is None, ) # [sq, b, hp] --> [sq, b, np, hn] diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index d0ba644621..cd0ecbaa6c 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -78,6 +78,16 @@ BACKEND_F16m512_FP8_THREADS_PER_CTA = 128 BACKEND_F16arb_ELTS_PER_THREADS = 16 +META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT +META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1 +META_O = tex.FP8FwdTensors.GEMM2_INPUT +META_DO = tex.FP8BwdTensors.GRAD_INPUT2 +META_S = tex.FP8FwdTensors.GEMM3_OUTPUT +META_DP = tex.FP8BwdTensors.GRAD_INPUT3 +# repurpose some unused amax history buffers for partial results of CP fwd and bwd +META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT +META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1 + def fused_attn_fwd_qkvpacked( is_training: bool, @@ -89,11 +99,17 @@ def fused_attn_fwd_qkvpacked( attn_bias: torch.Tensor = None, cu_seqlens_padded: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, + d_scale_qkv_offset: int = META_QKV, d_scale_s: torch.Tensor = None, + d_scale_s_offset: int = META_S, q_scale_s: torch.Tensor = None, + q_scale_s_offset: int = META_S, q_scale_o: torch.Tensor = None, + q_scale_o_offset: int = META_O, amax_s: torch.Tensor = None, + amax_s_offset: int = META_S, amax_o: torch.Tensor = None, + amax_o_offset: int = META_O, attn_scale: float = None, dropout: float = 0.0, fast_zero_fill: bool = True, @@ -128,16 +144,28 @@ def fused_attn_fwd_qkvpacked( cumulative sequence offsets for QKV; shape [batch_size + 1] d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of QKV in FP8 computations + d_scale_qkv_offset: int, default = META_QKV + offset in d_scale_qkv for QKV d_scale_s: torch.Tensor, default = None input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) + d_scale_s_offset: int, default = META_S + offset in d_scale_s for S q_scale_s: torch.Tensor, default = None input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) + q_scale_s_offset: int, default = META_S + offset in q_scale_s for S q_scale_o: torch.Tensor, default = None input tensor for the quantization of O in FP8 computations + q_scale_o_offset: int, default = META_O + offset in q_scale_o for O amax_s: torch.Tensor, default = None output tensor, amax of S, used by the next iteration in FP8 computations + amax_s_offset: int, default = META_S + offset in amax_s for S amax_o: torch.Tensor, default = None output tensor, amax of O, used by the next iteration in FP8 computations + amax_o_offset: int, default = META_O + offset in amax_o for O attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; if None, use 1.0/sqrt(head_dim_qk) as the default @@ -248,11 +276,17 @@ def fused_attn_fwd_qkvpacked( qkv_dtype, cu_seqlens_padded, d_scale_qkv, + d_scale_qkv_offset, d_scale_s, + d_scale_s_offset, q_scale_s, + q_scale_s_offset, q_scale_o, + q_scale_o_offset, amax_s, + amax_s_offset, amax_o, + amax_o_offset, attn_bias, rng_gen, rng_elts_per_thread, @@ -448,11 +482,17 @@ def fused_attn_fwd_kvpacked( cu_seqlens_q_padded: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, + d_scale_qkv_offset: int = META_QKV, d_scale_s: torch.Tensor = None, + d_scale_s_offset: int = META_S, q_scale_s: torch.Tensor = None, + q_scale_s_offset: int = META_S, q_scale_o: torch.Tensor = None, + q_scale_o_offset: int = META_O, amax_s: torch.Tensor = None, + amax_s_offset: int = META_S, amax_o: torch.Tensor = None, + amax_o_offset: int = META_O, attn_scale: float = None, dropout: float = 0.0, fast_zero_fill: bool = True, @@ -496,16 +536,28 @@ def fused_attn_fwd_kvpacked( cumulative sequence offsets for KV; shape [batch_size + 1] d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of QKV in FP8 computations + d_scale_qkv_offset: int, default = META_QKV + offset in d_scale_qkv for QKV d_scale_s: torch.Tensor, default = None input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) + d_scale_s_offset: int, default = META_S + offset in d_scale_s for S q_scale_s: torch.Tensor, default = None input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) + q_scale_s_offset: int, default = META_S + offset in q_scale_s for S q_scale_o: torch.Tensor, default = None input tensor for the quantization of O in FP8 computations + q_scale_o_offset: int, default = META_O + offset in q_scale_o for O amax_s: torch.Tensor, default = None output tensor, amax of S, used by the next iteration in FP8 computations + amax_s_offset: int, default = META_S + offset in amax_s for S amax_o: torch.Tensor, default = None output tensor, amax of O, used by the next iteration in FP8 computations + amax_o_offset: int, default = META_O + offset in amax_o for O attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; if None, use 1.0/sqrt(head_dim_qk) as the default @@ -621,11 +673,17 @@ def fused_attn_fwd_kvpacked( cu_seqlens_q_padded, cu_seqlens_kv_padded, d_scale_qkv, + d_scale_qkv_offset, d_scale_s, + d_scale_s_offset, q_scale_s, + q_scale_s_offset, q_scale_o, + q_scale_o_offset, amax_s, + amax_s_offset, amax_o, + amax_o_offset, attn_bias, rng_gen, rng_elts_per_thread, @@ -843,11 +901,17 @@ def fused_attn_fwd( cu_seqlens_q_padded: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, + d_scale_qkv_offset: int = META_QKV, d_scale_s: torch.Tensor = None, + d_scale_s_offset: int = META_S, q_scale_s: torch.Tensor = None, + q_scale_s_offset: int = META_S, q_scale_o: torch.Tensor = None, + q_scale_o_offset: int = META_O, amax_s: torch.Tensor = None, + amax_s_offset: int = META_S, amax_o: torch.Tensor = None, + amax_o_offset: int = META_O, attn_scale: float = None, dropout: float = 0.0, fast_zero_fill: bool = True, @@ -894,17 +958,29 @@ def fused_attn_fwd( cu_seqlens_kv_padded: torch.Tensor, default = None cumulative sequence offsets for KV; shape [batch_size + 1] d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of Q, K and V in FP8 computations + input tensor for the dequantization of QKV in FP8 computations + d_scale_qkv_offset: int, default = META_QKV + offset in d_scale_qkv for QKV d_scale_s: torch.Tensor, default = None input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) + d_scale_s_offset: int, default = META_S + offset in d_scale_s for S q_scale_s: torch.Tensor, default = None input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) + q_scale_s_offset: int, default = META_S + offset in q_scale_s for S q_scale_o: torch.Tensor, default = None input tensor for the quantization of O in FP8 computations + q_scale_o_offset: int, default = META_O + offset in q_scale_o for O amax_s: torch.Tensor, default = None output tensor, amax of S, used by the next iteration in FP8 computations + amax_s_offset: int, default = META_S + offset in amax_s for S amax_o: torch.Tensor, default = None output tensor, amax of O, used by the next iteration in FP8 computations + amax_o_offset: int, default = META_O + offset in amax_o for O attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; if None, use 1.0/sqrt(head_dim_qk) as the default @@ -1023,11 +1099,17 @@ def fused_attn_fwd( cu_seqlens_q_padded, cu_seqlens_kv_padded, d_scale_qkv, + d_scale_qkv_offset, d_scale_s, + d_scale_s_offset, q_scale_s, + q_scale_s_offset, q_scale_o, + q_scale_o_offset, amax_s, + amax_s_offset, amax_o, + amax_o_offset, attn_bias, rng_gen, rng_elts_per_thread, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 1a6f5f157e..45ef9951d7 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -48,11 +48,13 @@ std::vector fused_attn_fwd_qkvpacked( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread); + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread); std::vector fused_attn_bwd_qkvpacked( size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, @@ -75,11 +77,13 @@ std::vector fused_attn_fwd_kvpacked( const at::Tensor KV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread); + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread); std::vector fused_attn_bwd_kvpacked( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, @@ -104,11 +108,13 @@ std::vector fused_attn_fwd( const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread); + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread); std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, @@ -335,13 +341,18 @@ at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, fl **************************************************************************************************/ at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax, - at::Tensor scale_inv, transformer_engine::DType otype); + at::Tensor scale_inv, transformer_engine::DType otype, + const int scale_offset = 0, const int amax_offset = 0, + const int scale_inv_offset = 0); void cast_to_fp8_noalloc(const at::Tensor &input, const at::Tensor &scale, at::Tensor output, - at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype); + at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, + const int scale_offset = 0, const int amax_offset = 0, + const int scale_inv_offset = 0); at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv, - transformer_engine::DType itype, transformer_engine::DType otype); + transformer_engine::DType itype, transformer_engine::DType otype, + const int scale_inv_offset = 0); /*************************************************************************************************** * Softmax diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 50eb7b830f..fb1fc97a33 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -83,11 +83,13 @@ std::vector fused_attn_fwd_qkvpacked( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread) { + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; auto qkv_sizes = QKV.sizes().vec(); @@ -122,11 +124,14 @@ std::vector fused_attn_fwd_qkvpacked( NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); } te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(), - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), - scale_O.value().data_ptr(), nullptr); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, + getDataPtr(amax_S.value(), amax_S_offset), + getDataPtr(scale_S.value(), scale_S_offset), + getDataPtr(descale_S.value(), descale_S_offset)); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, + getDataPtr(amax_O.value(), amax_O_offset), + getDataPtr(scale_O.value(), scale_O_offset), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { O.fill_(0); @@ -393,11 +398,13 @@ std::vector fused_attn_fwd_kvpacked( const at::Tensor KV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread) { + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; auto q_sizes = Q.sizes().vec(); @@ -429,13 +436,16 @@ std::vector fused_attn_fwd_kvpacked( NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); } te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(), - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), - scale_O.value().data_ptr(), nullptr); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, + getDataPtr(amax_S.value(), amax_S_offset), + getDataPtr(scale_S.value(), scale_S_offset), + getDataPtr(descale_S.value(), descale_S_offset)); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, + getDataPtr(amax_O.value(), amax_O_offset), + getDataPtr(scale_O.value(), scale_O_offset), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { O.fill_(0); @@ -747,11 +757,13 @@ std::vector fused_attn_fwd( const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread) { + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; auto q_sizes = Q.sizes().vec(); @@ -788,15 +800,18 @@ std::vector fused_attn_fwd( NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); } te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(), - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, amax_O.value().data_ptr(), - scale_O.value().data_ptr(), nullptr); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, + getDataPtr(amax_S.value(), amax_S_offset), + getDataPtr(scale_S.value(), scale_S_offset), + getDataPtr(descale_S.value(), descale_S_offset)); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, + getDataPtr(amax_O.value(), amax_O_offset), + getDataPtr(scale_O.value(), scale_O_offset), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { O.fill_(0); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cu b/transformer_engine/pytorch/csrc/extensions/cast.cu index c783c9d988..47f5825866 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cu +++ b/transformer_engine/pytorch/csrc/extensions/cast.cu @@ -6,8 +6,9 @@ #include "extensions.h" -at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax, - at::Tensor scale_inv, transformer_engine::DType otype) { +at::Tensor cast_to_fp8(const at::Tensor& input, const at::Tensor& scale, at::Tensor amax, + at::Tensor scale_inv, transformer_engine::DType otype, + const int scale_offset, const int amax_offset, const int scale_inv_offset) { using namespace transformer_engine; auto input_shape = input.sizes().vec(); std::vector shape{input_shape.begin(), input_shape.end()}; @@ -16,32 +17,45 @@ at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Ten if (input.numel() == 0) return output; + // Get pointers for FP8 scale, amax, scale-inverse + void* scale_dptr = getDataPtr(scale, scale_offset); + void* amax_dptr = getDataPtr(amax, amax_offset); + void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + auto input_cu = makeTransformerEngineTensor(input); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype, amax_dptr, + scale_dptr, scale_inv_dptr); nvte_fp8_quantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return output; } -void cast_to_fp8_noalloc(const at::Tensor &input, const at::Tensor &scale, at::Tensor output, - at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype) { +void cast_to_fp8_noalloc(const at::Tensor& input, const at::Tensor& scale, at::Tensor output, + at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, + const int scale_offset, const int amax_offset, + const int scale_inv_offset) { using namespace transformer_engine; size_t N = static_cast(input.size(0)); size_t H = static_cast(input.size(1)); + // Get pointers for FP8 scale, amax, scale-inverse + void* scale_dptr = getDataPtr(scale, scale_offset); + void* amax_dptr = getDataPtr(amax, amax_offset); + void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + auto input_cu = makeTransformerEngineTensor(input); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype, amax_dptr, + scale_dptr, scale_inv_dptr); nvte_fp8_quantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return; } -at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv, - transformer_engine::DType itype, transformer_engine::DType otype) { +at::Tensor cast_from_fp8(const at::Tensor& input, const at::Tensor& scale_inv, + transformer_engine::DType itype, transformer_engine::DType otype, + const int scale_inv_offset) { using namespace transformer_engine; auto input_shape = input.sizes().vec(); std::vector shape{input_shape.begin(), input_shape.end()}; @@ -49,7 +63,7 @@ at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv, auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); auto input_cu = makeTransformerEngineTensor(input.data_ptr(), shape, itype, nullptr, nullptr, - scale_inv.data_ptr()); + getDataPtr(scale_inv, scale_inv_offset)); auto output_cu = makeTransformerEngineTensor(output); nvte_fp8_dequantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index f903a1c35b..dc82b6e2df 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -93,10 +93,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc, "Fused Multi-tensor Cast + Transpose with allocating output tensors", py::call_guard()); - m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8", py::call_guard()); + m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8", py::call_guard(), + py::arg("input"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), + py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8", - py::call_guard()); - m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8", py::call_guard()); + py::call_guard(), py::arg("input"), py::arg("scale"), + py::arg("output"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), + py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); + m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8", py::call_guard(), + py::arg("input"), py::arg("scale_inv"), py::arg("itype"), py::arg("otype"), + py::arg("scale_inv_offset") = 0); m.def("te_gemm", &te_gemm, "CublasLt GEMM"); /// TODO Think m.def("te_grouped_gemm", &te_grouped_gemm, "Grouped GEMM"); m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked, diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp index 8515092ae0..8c480e8343 100644 --- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp +++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp @@ -26,7 +26,7 @@ at::Tensor cast_to_fp8_ts(const at::Tensor &input, const at::Tensor &scale, at:: at::Tensor scale_inv, int64_t fp8_tensor, int64_t otype) { transformer_engine::DType otype_arg = reverse_map_dtype(otype); at::Tensor output = - cast_to_fp8(input, scale[fp8_tensor], amax[0][fp8_tensor], scale_inv[fp8_tensor], otype_arg); + cast_to_fp8(input, scale, amax, scale_inv, otype_arg, fp8_tensor, fp8_tensor, fp8_tensor); return output; } @@ -34,8 +34,8 @@ at::Tensor cast_to_fp8_noalloc_ts(const at::Tensor &input, const at::Tensor &sca at::Tensor output, at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor, int64_t otype) { transformer_engine::DType otype_arg = reverse_map_dtype(otype); - cast_to_fp8_noalloc(input, scale[fp8_tensor], output, amax[0][fp8_tensor], scale_inv[fp8_tensor], - otype_arg); + cast_to_fp8_noalloc(input, scale, output, amax, scale_inv, otype_arg, fp8_tensor, fp8_tensor, + fp8_tensor); return output; } @@ -43,7 +43,7 @@ at::Tensor cast_from_fp8_ts(const at::Tensor &input, const at::Tensor &scale_inv int64_t fp8_tensor, int64_t itype, int64_t otype) { transformer_engine::DType itype_arg = reverse_map_dtype(itype); transformer_engine::DType otype_arg = reverse_map_dtype(otype); - at::Tensor output = cast_from_fp8(input, scale_inv[fp8_tensor], itype_arg, otype_arg); + at::Tensor output = cast_from_fp8(input, scale_inv, itype_arg, otype_arg, fp8_tensor); return output; } diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index d6045d8e77..9586d6d345 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -91,6 +91,7 @@ def forward( ub_overlap_rs_dgrad: bool, ub_overlap_ag: bool, ub_name: str, + fp8_output: bool, fsdp_group: Union[dist_group_type, None], ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible @@ -220,7 +221,7 @@ def forward( if is_in_onnx_export_mode(): ln_out_scale_inv.fill_(ln_out_scale_inv.item()) - if fp8_meta["recipe"].fp8_mha: + if fp8_output: out_index, meta_tensor, output_te_dtype, output_dtype = ( tex.FP8FwdTensors.GEMM1_OUTPUT, fp8_meta["scaling_fwd"], @@ -765,6 +766,7 @@ def backward( None, # ub_overlap_rs_dgrad None, # ub_overlap_ag None, # ub_name + None, # fp8_output None, # fsdp_group ) @@ -1117,6 +1119,7 @@ def forward( self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None, + fp8_output: Optional[bool] = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply layer normalization to the input followed by a linear transformation. @@ -1244,6 +1247,7 @@ def forward( self.ub_overlap_rs_dgrad, self.ub_overlap_ag, self.ub_name, + fp8_output, self.fsdp_group, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 175e5ab5cf..f92a2db2d9 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -82,12 +82,10 @@ def forward( ub_overlap_rs: bool, ub_overlap_ag: bool, ub_name: str, - is_first_module_in_mha: bool, + fp8_output: bool, fsdp_group: Union[dist_group_type, None], ) -> torch.Tensor: is_input_fp8 = isinstance(inp, Float8Tensor) - if is_input_fp8: - fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT] = inp._scale_inv[0] # Make sure input dimensions are compatible in_features = weight.shape[-1] @@ -110,14 +108,6 @@ def forward( fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) if isinstance(inputmat, Float8Tensor): inputmat_scale_inv = inputmat._scale_inv - if ( - not fp8_meta["recipe"].override_linear_precision.wgrad - and is_grad_enabled - and weight.requires_grad - and not sequence_parallel - ): - # FP8 input for forward, FP8 input transpose for backward wgrad - inputmat_t = inputmat.transpose_2d() else: inputmat_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device) if ( @@ -171,7 +161,7 @@ def forward( assert isinstance(weight_fp8, Float8Tensor) - if is_first_module_in_mha: + if fp8_output: proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( tex.FP8FwdTensors.GEMM1_OUTPUT, fp8_meta["scaling_fwd"], @@ -240,7 +230,7 @@ def forward( fp8_meta_tensor=meta_tensor, D_dtype=proj_out_tetype, ) - if is_first_module_in_mha: + if fp8_output: out = Float8Tensor( data=out, fp8_meta=fp8_meta, @@ -639,7 +629,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # ub_overlap_rs None, # ub_overlap_ag None, # ub_name - None, # is_first_module_in_mha + None, # fp8_output None, # fsdp_group ) @@ -917,7 +907,7 @@ def forward( self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None, - is_first_module_in_mha: Optional[bool] = False, + fp8_output: Optional[bool] = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply the linear transformation to the input. @@ -951,8 +941,6 @@ def forward( allow_non_contiguous=isinstance(inp, Float8Tensor), ) as inp: - is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha - # Get concatenated weight and bias tensors unfused_weights = [getattr(self, name) for name in self.weight_names] if any(isinstance(w, Float8Tensor) for w in unfused_weights): @@ -1037,7 +1025,7 @@ def forward( self.ub_overlap_rs, self.ub_overlap_ag, self.ub_name, - is_first_module_in_mha, + fp8_output, self.fsdp_group, ) out = linear_fn(*args) From 454e389502ad4ed4f90b0990a631fe12bdf968fd Mon Sep 17 00:00:00 2001 From: Selvaraj Anandaraj Date: Thu, 5 Sep 2024 10:54:08 -0700 Subject: [PATCH 25/50] Added offloading support FP8 attention (#1131) * Added offloading support FP8 attention Signed-off-by: Selvaraj Anandaraj * Update transformer_engine/pytorch/attention.py Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Selvaraj Anandaraj * Fix Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Selvaraj Anandaraj Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/attention.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 59bc26140d..91c14899ec 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -5698,16 +5698,23 @@ def forward( out_save = out_ret fp8_tensors = (None, None, None, None, None, None) + ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + from .cpu_offload import CPUOffloadEnabled if CPUOffloadEnabled: - tensor_list = [q, k, v, out_save, cu_seqlens_q, cu_seqlens_kv] + if ctx.fp8: + tensor_list = fp8_tensors + else: + tensor_list = [q, k, v, out_save] + + tensor_list.extend(aux_ctx_tensors) + qkv_layout = "sbhd_sbhd_sbhd" for tensor in tensor_list: if tensor is not None: tensor.activation_offloading = True - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) From 215db88dcd294ef4ddbbb635cd51956a35fc1e4f Mon Sep 17 00:00:00 2001 From: Ruibin Cheung Date: Fri, 6 Sep 2024 02:03:24 +0800 Subject: [PATCH 26/50] [PyTorch] Implement Fp8 padding and unpadding module (#1129) * [TE/PyTorch][MoE] Add FP8 padding and unpadding module 1. Add multi-tensor padding kernel for FP8 with padding size = 16. 2. Add FP8Padding and Fp8Unpadding module 3. Add Padded GroupedLinear unit tests --------- Signed-off-by: beinggod Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> --- tests/cpp/operator/CMakeLists.txt | 1 + tests/cpp/operator/test_multi_padding.cu | 169 ++++++++++++++ tests/pytorch/test_numerics.py | 189 +++++++++++++++ transformer_engine/common/CMakeLists.txt | 1 + .../include/transformer_engine/padding.h | 51 ++++ transformer_engine/common/util/padding.cu | 219 ++++++++++++++++++ transformer_engine/pytorch/__init__.py | 1 + .../pytorch/cpp_extensions/__init__.py | 1 + .../pytorch/cpp_extensions/padding.py | 29 +++ transformer_engine/pytorch/csrc/common.h | 1 + transformer_engine/pytorch/csrc/extensions.h | 8 + .../pytorch/csrc/extensions/padding.cu | 79 +++++++ .../pytorch/csrc/extensions/pybind.cpp | 3 +- transformer_engine/pytorch/module/__init__.py | 2 + .../pytorch/module/fp8_padding.py | 123 ++++++++++ .../pytorch/module/fp8_unpadding.py | 119 ++++++++++ 16 files changed, 995 insertions(+), 1 deletion(-) create mode 100644 tests/cpp/operator/test_multi_padding.cu create mode 100644 transformer_engine/common/include/transformer_engine/padding.h create mode 100644 transformer_engine/common/util/padding.cu create mode 100644 transformer_engine/pytorch/cpp_extensions/padding.py create mode 100644 transformer_engine/pytorch/csrc/extensions/padding.cu create mode 100644 transformer_engine/pytorch/module/fp8_padding.py create mode 100644 transformer_engine/pytorch/module/fp8_unpadding.py diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index e590d8e92a..45806e7022 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -13,6 +13,7 @@ add_executable(test_operator test_layernorm.cu test_rmsnorm.cu test_multi_cast_transpose.cu + test_multi_padding.cu test_causal_softmax.cu ../test_common.cu) diff --git a/tests/cpp/operator/test_multi_padding.cu b/tests/cpp/operator/test_multi_padding.cu new file mode 100644 index 0000000000..e9e42725fe --- /dev/null +++ b/tests/cpp/operator/test_multi_padding.cu @@ -0,0 +1,169 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +template +void compute_ref(const std::vector>& input_list, + std::vector>& output_list, + const std::vector& height_list, + const std::vector& width_list, + const std::vector& padded_height_list) { + using compute_t = float; + for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { + const auto& input = input_list[tensor_id]; + auto& output = output_list[tensor_id]; + const size_t height = height_list[tensor_id]; + const size_t width = width_list[tensor_id]; + const size_t padded_height = padded_height_list[tensor_id]; + + for (size_t i = 0; i < padded_height; ++i) { + if (i < height) { + for (size_t j = 0; j < width; ++j) { + const compute_t x = static_cast(input[i * width + j]); + const OutputType y = static_cast(x); + output[i * width + j] = y; + } + } else { + for (size_t j = 0; j < width; ++j) { + output[i * width + j] = static_cast(0.f); + } + } + } + } +} + +template +void performTest() { + using namespace test; + + const DType itype = TypeInfo::dtype; + const DType otype = TypeInfo::dtype; + const std::vector> tensor_dims = {{1,1}, + {1,768}, + {768,1}, + {768,768}, + {43,43}, + {43,256}, + {256,43}, + {256,256}}; + const size_t num_tensors = tensor_dims.size(); + constexpr int align = 16; + + // Buffers for Transformer Engine implementation + std::vector input_list, output_list, output_t_list; + + // Buffers for reference implementation + std::vector> ref_input_list; + std::vector> ref_output_list; + std::vector ref_height_list(num_tensors), ref_width_list(num_tensors); + std::vector ref_padded_height_list(num_tensors); + + // Initialize buffers + for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { + const size_t height = tensor_dims[tensor_id].first; + const size_t width = tensor_dims[tensor_id].second; + const size_t padded_height = (height + align - 1) / align * align; + input_list.emplace_back(Tensor({ height, width }, itype)); + output_list.emplace_back(Tensor({ padded_height, width }, otype)); + + auto& input = input_list.back(); + auto& output = output_list.back(); + fillUniform(&input); + setRandomScale(&output); + + ref_input_list.emplace_back(height*width); + ref_output_list.emplace_back(padded_height*width); + + std::copy(input.cpu_dptr(), + input.cpu_dptr() + height * width, + ref_input_list.back().begin()); + ref_height_list[tensor_id] = height; + ref_width_list[tensor_id] = width; + ref_padded_height_list[tensor_id] = padded_height; + } + + // Transformer Engine implementation + auto make_nvte_vector = [](std::vector& tensor_list) + -> std::vector { + std::vector nvte_tensor_list; + for (auto& tensor : tensor_list) { + nvte_tensor_list.emplace_back(tensor.data()); + } + return nvte_tensor_list; + }; + nvte_multi_padding(num_tensors, + make_nvte_vector(input_list).data(), + make_nvte_vector(output_list).data(), + ref_padded_height_list.data(), + 0); + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + // Reference implementation + compute_ref(ref_input_list, + ref_output_list, + ref_height_list, + ref_width_list, + ref_padded_height_list); + + // Check correctness + for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { + auto [atol, rtol] = getTolerances(otype); + compareResults("output", + output_list[tensor_id], + ref_output_list[tensor_id].data(), + atol, rtol); + } +} + +} // namespace + +class MultiPaddingTestSuite + : public ::testing::TestWithParam< + transformer_engine::DType> {}; + +TEST_P(MultiPaddingTestSuite, TestMultiPaddingTranspose) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = GetParam(); + const DType output_type = input_type; + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTest(); + ); + ); +} + + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + MultiPaddingTestSuite, + ::testing::ValuesIn(test::all_fp_types), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(info.param); + return name; + }); diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 85cd4fc256..723f68369b 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -7,6 +7,7 @@ from typing import Dict, List, Optional import pytest import copy +import random import torch import torch.nn as nn @@ -30,6 +31,8 @@ TransformerLayer, LayerNorm, InferenceParams, + Fp8Padding, + Fp8Unpadding, ) from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm @@ -354,6 +357,40 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return (input > 0) * input * input +class TorchGroupedLinearWithPadding(nn.Module): + + def __init__( + self, num_gemms, in_features, out_features, bias, params_dtype, parallel_mode, fp8 + ) -> None: + super().__init__() + + self.padding = Fp8Padding(num_gemms) + self.linear_fn = GroupedLinear( + num_gemms, + in_features, + out_features, + bias=bias, + params_dtype=params_dtype, + parallel_mode=parallel_mode, + device="cuda", + ) + self.unpadding = Fp8Unpadding(num_gemms) + + self.fp8 = fp8 + + def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor: + if self.fp8: + orig_m_splits = m_splits + inp, m_splits = self.padding(inp, m_splits) + + out = self.linear_fn(inp, m_splits) + + if self.fp8: + out = self.unpadding(out, orig_m_splits) + + return out + + _supported_act = { "geglu": nn.GELU(approximate="tanh"), "gelu": nn.GELU(approximate="tanh"), @@ -1328,6 +1365,158 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode): ) +def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False): + + def _pad_tensor_for_fp8(hidden_states, tokens_per_expert): + """Padding tensor shapes to multiples of 16.""" + padded_tokens_per_expert = [ + (num_tokens + 15) // 16 * 16 for num_tokens in tokens_per_expert + ] + hidden_states = torch.split(hidden_states, tokens_per_expert) + padded_hidden_states = [] + for hidden_state, actual_num_tokens, padded_num_tokens in zip( + hidden_states, tokens_per_expert, padded_tokens_per_expert + ): + padded_hidden_states.append(hidden_state) + if padded_num_tokens > actual_num_tokens: + pad_tensor = torch.zeros( + padded_num_tokens - actual_num_tokens, + hidden_state.shape[1], + dtype=hidden_state.dtype, + device=hidden_state.device, + ) + padded_hidden_states.append(pad_tensor) + padded_hidden_states = torch.cat(padded_hidden_states, dim=0) + return padded_hidden_states, padded_tokens_per_expert + + def _unpad_tensor_for_fp8(padded_hidden_states, actual_tokens_per_expert, tokens_per_expert): + inputmats = torch.split( + padded_hidden_states.view(-1, padded_hidden_states.shape[-1]), tokens_per_expert + ) + hidden_states = torch.cat( + [ + grad_output_mat[: actual_tokens_per_expert[i]] + for i, grad_output_mat in enumerate(inputmats) + ], + dim=0, + ) + + return hidden_states + + def _generate_random_numbers(n, total_sum): + if n <= 0: + return [] + + # reset seed + random.seed(seed) + + breaks = sorted(random.sample(range(1, total_sum), n - 1)) + random_numbers = ( + [breaks[0]] + + [breaks[i] - breaks[i - 1] for i in range(1, n - 1)] + + [total_sum - breaks[-1]] + ) + + return random_numbers + + reset_rng_states() + if fp8: + FP8GlobalStateManager.reset() + + inp_hidden_states = torch.randn( + (config.seq_len * bs, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) + inp_hidden_states.retain_grad() + + m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs) + + with fp8_autocast(enabled=fp8): + if isinstance(block, TorchGroupedLinearWithPadding): + out = block(inp_hidden_states, m_splits) + else: + if fp8: + padded_inp_hidden_states, padding_m_splits = _pad_tensor_for_fp8( + inp_hidden_states, m_splits + ) + padded_inp_hidden_states = block(padded_inp_hidden_states, padding_m_splits) + out = _unpad_tensor_for_fp8(padded_inp_hidden_states, m_splits, padding_m_splits) + else: + out = block(inp_hidden_states, m_splits) + + loss = out.sum() + loss.backward() + + torch.cuda.synchronize() + outputs = [out, inp_hidden_states.grad] + for p in block.parameters(): + if p.requires_grad: + outputs.append(p.grad) + return outputs + + +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("num_gemms", [3, 6]) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("fp8", [True]) +@pytest.mark.parametrize("fp8_model_params", all_boolean) +def test_padding_grouped_linear_accuracy( + dtype, num_gemms, bs, model, fp8, fp8_model_params, parallel_mode=None +): + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + + config = model_configs[model] + if config.seq_len % 16 != 0 and fp8: + pytest.skip("FP8 requires sequence length to be divisible by 16.") + + with fp8_model_init(enabled=fp8 and fp8_model_params): + grouped_linear = TorchGroupedLinearWithPadding( + num_gemms, + config.hidden_size, + 4 * config.hidden_size, + bias=False, + params_dtype=dtype, + parallel_mode=parallel_mode, + fp8=fp8, + ).eval() + + with fp8_model_init(enabled=fp8 and fp8_model_params): + ref_grouped_linear = GroupedLinear( + num_gemms, + config.hidden_size, + 4 * config.hidden_size, + bias=False, + params_dtype=dtype, + parallel_mode=parallel_mode, + device="cuda", + ).eval() + + # Share params + with torch.no_grad(): + inner_grouped_linear = grouped_linear.linear_fn + for i in range(num_gemms): + setattr( + ref_grouped_linear, + f"weight{i}", + Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()), + ) + + outputs = _test_padding_grouped_linear_accuracy( + grouped_linear, num_gemms, bs, dtype, config, fp8 + ) + outputs_ref = _test_padding_grouped_linear_accuracy( + ref_grouped_linear, num_gemms, bs, dtype, config, fp8 + ) + + # Shoule be bit-wise match + for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + + def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph): reset_rng_states() diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index a6fd6815c3..647d2c474d 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -71,6 +71,7 @@ list(APPEND transformer_engine_SOURCES rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu rmsnorm/rmsnorm_fwd_cuda_kernel.cu util/cast.cu + util/padding.cu util/cuda_driver.cpp util/cuda_runtime.cpp util/rtc.cpp diff --git a/transformer_engine/common/include/transformer_engine/padding.h b/transformer_engine/common/include/transformer_engine/padding.h new file mode 100644 index 0000000000..a419b38234 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/padding.h @@ -0,0 +1,51 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file padding.h + * \brief Functions handling padding. + */ + +#ifndef TRANSFORMER_ENGINE_PADDING_H_ +#define TRANSFORMER_ENGINE_PADDING_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Padding multiple tensors. + * + * NOTE: Padding mode only support bottom. + * + * For example, 3x3 matrix pad to 4x3 matrix. + * + * source + * | 1 | 2 | 3 | + * | 4 | 5 | 6 | + * | 7 | 8 | 9 | + * + * destination + * | 1 | 2 | 3 | + * | 4 | 5 | 6 | + * | 7 | 8 | 9 | + * | 0 | 0 | 0 | + * + * \param[in] num_tensors Number of tensors. + * \param[in] input_list List of 2D input tensors. + * \param[in,out] output_list List of padded tensors. Dimensions + * match tensors in input_list. + * \param[in] padded_num_rows_list List of padded num rows corresponding to input tensors. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list, + const int* padded_num_rows_list, cudaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_PADDING_H_ diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu new file mode 100644 index 0000000000..017d2e6a56 --- /dev/null +++ b/transformer_engine/common/util/padding.cu @@ -0,0 +1,219 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include +#include +#include + +#include "../common.h" +#include "../utils.cuh" + +namespace transformer_engine { + +namespace { + +// Parameters to tune +constexpr int n_warps_per_tile = 4; +constexpr int threads_per_block = THREADS_PER_WARP * n_warps_per_tile; +constexpr int desired_load_store_size = 8; +constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB + +struct MultiPaddingArgs { + // (input) Data buffers for input tensors + void* input_list[kMaxTensorsPerKernel]; + // (output) Data buffers for cast output tensors + void* output_list[kMaxTensorsPerKernel]; + // Input matrix heights + int num_rows_list[kMaxTensorsPerKernel]; + // Input matrix heights (padded) + int padded_num_rows_list[kMaxTensorsPerKernel]; + // Input matrix widths + int row_length_list[kMaxTensorsPerKernel]; + // tensor + int block_range[kMaxTensorsPerKernel + 1]; + // Number of tensors being processed by kernel + int num_tensors; +}; + +template +__global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiPaddingArgs args) { + using Vec = Vec; + + // Thread indices + // Note: Block is interpreted as a warp_size x num_warps grid + constexpr int bdimx = THREADS_PER_WARP; + constexpr int bdimy = n_warps_per_tile; + const int tid = threadIdx.x; + const int tidx = tid % bdimx; + const int tidy = tid / bdimx; + const int bid = blockIdx.x; + + // Input tensors are divided into tiles + // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles + constexpr int tile_dim_m = THREADS_PER_WARP * nvec; + constexpr int tile_dim_n = THREADS_PER_WARP * nvec; + + // Number of nvec x nvec subtiles for each thread to + // load/store + constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile; + + // Find tensor corresponding to block + int tensor_id = 0; + while (args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } + const Type* input = reinterpret_cast(args.input_list[tensor_id]); + Type* output = reinterpret_cast(args.output_list[tensor_id]); + const int num_rows = args.num_rows_list[tensor_id]; + const int padded_num_rows = args.padded_num_rows_list[tensor_id]; + const int row_length = args.row_length_list[tensor_id]; + + // Find position of tile within tensor + const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; + const int tile_id = bid - args.block_range[tensor_id]; + const int tile_id_m = tile_id / num_tiles_n; + const int tile_id_n = tile_id % num_tiles_n; + const int tile_row = tile_id_m * tile_dim_m; + const int tile_col = tile_id_n * tile_dim_n; + + // Load input and store to registers + // Note: Each thread loads n_iterations subtiles, casts to output + // type, and transposes in registers. + Type local_zero = static_cast(0.f); +#pragma unroll + for (int iter = 0; iter < n_iterations; ++iter) { + const int i1 = tidy + iter * bdimy; + const int j1 = tidx; +#pragma unroll + for (int i2 = 0; i2 < nvec; ++i2) { + const int row = tile_row + i1 * nvec + i2; + const int col = tile_col + j1 * nvec; + Vec local_input; + Vec local_output; + local_input.clear(); + if (row < num_rows) { + for (int j2 = 0; j2 < nvec; ++j2) { + if (col + j2 < row_length) { + local_input.data.elt[j2] = input[row * row_length + col + j2]; + } + } + } +#pragma unroll + for (int j2 = 0; j2 < nvec; ++j2) { + local_output.data.elt[j2] = local_input.data.elt[j2]; + } + if (row < num_rows) { + for (int j2 = 0; j2 < nvec; ++j2) { + if (col + j2 < row_length) { + output[row * row_length + col + j2] = local_output.data.elt[j2]; + } + } + } else if (row < padded_num_rows) { + // padding + for (int j2 = 0; j2 < nvec; ++j2) { + if (col + j2 < row_length) { + output[row * row_length + col + j2] = local_zero; + } + } + } + } + } +} + +} // namespace + +void multi_padding(const std::vector input_list, std::vector output_list, + const std::vector padded_num_rows_list, cudaStream_t stream) { + // Check that number of tensors is valid + NVTE_CHECK(output_list.size() == input_list.size(), + "Number of input and output tensors must match"); + if (input_list.empty()) { + return; + } + + // Check that tensor properties are valid + DType type = input_list[0]->data.dtype; + for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { + const auto& input = *input_list[tensor_id]; + const auto& output = *output_list[tensor_id]; + CheckInputTensor(input, "multi_padding_input_" + std::to_string(tensor_id)); + CheckInputTensor(output, "multi_padding_output_" + std::to_string(tensor_id)); + + NVTE_CHECK(input.data.dtype == type, "Input tensor types do not match."); + NVTE_CHECK(output.data.dtype == type, "Output tensor types do not match."); + + NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions."); + NVTE_CHECK(output.data.shape[0] == padded_num_rows_list[tensor_id], + "output tensor shape does not match padded input shape."); + } + + // Input matrices are divided into tiles + // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles + const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size / typeToSize(type); + const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size / typeToSize(type); + + // Add tensors to kernel argument struct + MultiPaddingArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.block_range[0] = 0; + for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { + // Launch kernel if argument struct is full + if (kernel_args.num_tensors == kMaxTensorsPerKernel) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type); + const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; + multi_padding_kernel + <<>>(kernel_args);); // NOLINT(*) + kernel_args.num_tensors = 0; + } + + // Calculate number of thread blocks needed for tensor + const int num_rows = input_list[tensor_id]->data.shape[0]; + const int padded_num_rows = padded_num_rows_list[tensor_id]; + const int row_length = input_list[tensor_id]->data.shape[1]; + const int num_tiles_m = (padded_num_rows + tile_dim_m - 1) / tile_dim_m; + const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; + const int num_tiles = num_tiles_m * num_tiles_n; + + // Add tensor to kernel argument struct + const int pos = kernel_args.num_tensors; + kernel_args.input_list[pos] = const_cast(input_list[tensor_id]->data.dptr); + kernel_args.output_list[pos] = output_list[tensor_id]->data.dptr; + kernel_args.num_rows_list[pos] = num_rows; + kernel_args.padded_num_rows_list[pos] = padded_num_rows; + kernel_args.row_length_list[pos] = row_length; + kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles; + kernel_args.num_tensors++; + } + + // Launch kernel + if (kernel_args.num_tensors > 0) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type); + const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; + multi_padding_kernel + <<>>(kernel_args);); // NOLINT(*) + } +} + +} // namespace transformer_engine + +void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list, + const int* padded_num_rows_list, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_padding); + using namespace transformer_engine; + std::vector input_list_, output_list_; + std::vector padded_num_rows_list_; + for (size_t i = 0; i < num_tensors; ++i) { + input_list_.push_back(reinterpret_cast(const_cast(input_list[i]))); + output_list_.push_back(reinterpret_cast(output_list[i])); + padded_num_rows_list_.push_back(padded_num_rows_list[i]); + } + multi_padding(input_list_, output_list_, padded_num_rows_list_, stream); +} diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 89b20002a7..c4097333d3 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -67,6 +67,7 @@ def _load_library(): from transformer_engine.pytorch.module import LayerNorm from transformer_engine.pytorch.module import RMSNorm from transformer_engine.pytorch.module import GroupedLinear +from transformer_engine.pytorch.module import Fp8Padding, Fp8Unpadding from transformer_engine.pytorch.module import initialize_ub from transformer_engine.pytorch.module import destroy_ub from transformer_engine.pytorch.attention import DotProductAttention diff --git a/transformer_engine/pytorch/cpp_extensions/__init__.py b/transformer_engine/pytorch/cpp_extensions/__init__.py index 61d688f3f4..9f3c1b2424 100644 --- a/transformer_engine/pytorch/cpp_extensions/__init__.py +++ b/transformer_engine/pytorch/cpp_extensions/__init__.py @@ -11,3 +11,4 @@ from .activation import * from .normalization import * from .cast import * +from .padding import * diff --git a/transformer_engine/pytorch/cpp_extensions/padding.py b/transformer_engine/pytorch/cpp_extensions/padding.py new file mode 100644 index 0000000000..41dfbe2466 --- /dev/null +++ b/transformer_engine/pytorch/cpp_extensions/padding.py @@ -0,0 +1,29 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Python interface for transpose extensions""" +from typing import List, Tuple, Union +import torch +import transformer_engine_torch as tex + + +__all__ = [ + "multi_padding_fused", +] + + +def multi_padding_fused( + inp: torch.Tensor, + row_list: List[int], + padded_row_list: List[int], + out: torch.Tensor, +) -> Union[Tuple[List[torch.Tensor], List[torch.Tensor]], None]: + """Padding""" + + tex.fused_multi_row_padding( + inp, + out, + row_list, + padded_row_list, + ) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 7fb9953f94..04a1193a71 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 45ef9951d7..31103cbe8e 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -486,4 +486,12 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, float momentum, float dampening, float lr, bool nesterov, bool first_run, bool wd_after_momentum, float scale); +/*************************************************************************************************** + * padding + **************************************************************************************************/ + +void fused_multi_row_padding(at::Tensor input, at::Tensor output, + std::vector input_row_list, + std::vector padded_input_row_list); + #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cu b/transformer_engine/pytorch/csrc/extensions/padding.cu new file mode 100644 index 0000000000..d975ebeeef --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/padding.cu @@ -0,0 +1,79 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "extensions.h" + +void fused_multi_row_padding(at::Tensor input, at::Tensor output, + std::vector input_row_list, + std::vector padded_input_row_list) { + using namespace transformer_engine; + + NVTE_CHECK(input_row_list.size() == padded_input_row_list.size(), + "Number of input row list and padded row list must match."); + NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2."); + NVTE_CHECK(output.dim() == 2, "Dimension of output must equal 2."); + + const int num_tensors = input_row_list.size(); + // Extract properties from PyTorch tensors + std::vector input_dptr_list, output_dptr_list; + std::vector> input_shape_list, output_shape_list; + std::vector input_type_list; + void* d_input_ptr = reinterpret_cast(input.data_ptr()); + void* d_output_ptr = reinterpret_cast(output.data_ptr()); + for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { + input_dptr_list.push_back(d_input_ptr); + output_dptr_list.push_back(d_output_ptr); + + // Move the input pointer to the next split. + char* input_char_ptr = reinterpret_cast(d_input_ptr); + const size_t input_dptr_offset = + input_row_list[tensor_id] * input.size(1) * input.element_size(); + input_char_ptr += input_dptr_offset; + d_input_ptr = reinterpret_cast(input_char_ptr); + + input_shape_list.push_back({input_row_list[tensor_id], static_cast(input.size(1))}); + input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); + + // Move the output pointer to the next split. + char* output_char_ptr = reinterpret_cast(d_output_ptr); + const size_t output_dptr_offset = + padded_input_row_list[tensor_id] * output.size(1) * output.element_size(); + output_char_ptr += output_dptr_offset; + d_output_ptr = reinterpret_cast(output_char_ptr); + + output_shape_list.push_back( + {padded_input_row_list[tensor_id], static_cast(output.size(1))}); + } + + // Construct TE tensors + std::vector nvte_input_list, nvte_output_list; + std::vector tensor_wrappers; + auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, + transformer_engine::DType dtype) -> NVTETensor { + tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype)); + return tensor_wrappers.back().data(); + }; + + std::vector padded_num_rows_list; + for (size_t i = 0; i < input_dptr_list.size(); ++i) { + if (input_dptr_list[i] == nullptr || input_row_list[i] == 0) continue; + nvte_input_list.emplace_back( + make_tensor(input_dptr_list[i], input_shape_list[i], input_type_list[i])); + nvte_output_list.emplace_back( + make_tensor(output_dptr_list[i], output_shape_list[i], input_type_list[i])); + padded_num_rows_list.emplace_back(padded_input_row_list[i]); + } + + // Check tensor lists + NVTE_CHECK(nvte_output_list.size() == nvte_input_list.size(), + "Number of input and output tensors must match"); + NVTE_CHECK(padded_num_rows_list.size() == nvte_input_list.size() && + "Number of input and padded row list must match"); + + // Launch TE kernel + nvte_multi_padding(nvte_input_list.size(), nvte_input_list.data(), nvte_output_list.data(), + padded_num_rows_list.data(), at::cuda::getCurrentCUDAStream()); +} diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index dc82b6e2df..7bd5a2d8c8 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -152,7 +152,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction, "Update amax history and FP8 scale/scale_inv after reduction", py::call_guard()); - + m.def("fused_multi_row_padding", &fused_multi_row_padding, "Fused Multi-tensor padding", + py::call_guard()); // fused apply rope m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD", py::call_guard()); diff --git a/transformer_engine/pytorch/module/__init__.py b/transformer_engine/pytorch/module/__init__.py index 6994f586b1..ba4755efe3 100644 --- a/transformer_engine/pytorch/module/__init__.py +++ b/transformer_engine/pytorch/module/__init__.py @@ -9,4 +9,6 @@ from .layernorm_mlp import LayerNormMLP from .layernorm import LayerNorm from .rmsnorm import RMSNorm +from .fp8_padding import Fp8Padding +from .fp8_unpadding import Fp8Unpadding from .base import initialize_ub, destroy_ub diff --git a/transformer_engine/pytorch/module/fp8_padding.py b/transformer_engine/pytorch/module/fp8_padding.py new file mode 100644 index 0000000000..60bac91353 --- /dev/null +++ b/transformer_engine/pytorch/module/fp8_padding.py @@ -0,0 +1,123 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""FP8 Padding API""" + +from typing import Union, List + +import torch + +from ..cpp_extensions import ( + multi_padding_fused, +) +from ..jit import no_torch_dynamo + + +__all__ = ["Fp8Padding"] + + +class _Fp8Padding(torch.autograd.Function): + """functional FP8 padding""" + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + m_splits: List[int], + padded_m_splits: List[int], + is_grad_enabled: bool, + ) -> torch.Tensor: + # Make sure input dimensions are compatible + in_features = inp.shape[-1] + + # Allocate cast and transpose output tensor + total_row = sum(padded_m_splits) + out = torch.empty([total_row, in_features], dtype=inp.dtype, device=inp.device) + + multi_padding_fused(inp.view(-1, in_features), m_splits, padded_m_splits, out) + + if is_grad_enabled: + ctx.m_splits = m_splits + ctx.padded_m_splits = padded_m_splits + ctx.requires_dgrad = inp.requires_grad + + return out + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + + grad_input = None + if ctx.requires_dgrad: + grad_output = grad_output.contiguous() + + grad_output_mats = torch.split( + grad_output.view(-1, grad_output.shape[-1]), ctx.padded_m_splits + ) + grad_input = torch.cat( + [ + grad_output_mat[: ctx.m_splits[i]] + for i, grad_output_mat in enumerate(grad_output_mats) + ], + dim=0, + ) + + return (grad_input, None, None, None) + + +class Fp8Padding(torch.nn.Module): + """ + Apply the padding for Grouped GEMM input. + + Parameters + ---------- + num_gemms: int + number of GEMMs to be performed simutaneously. + """ + + def __init__( + self, + num_gemms, + ) -> None: + super().__init__() + + self.num_gemms = num_gemms + + @no_torch_dynamo() + def forward( + self, + inp: torch.Tensor, + m_splits: List[int], + ) -> Union[torch.Tensor, List[int]]: + """ + Apply the padding to the input. + + Parameters + ---------- + inp : torch.Tensor + Input tensor. + m_splits : List[int] + List of integers representing the split of the input tensor. + """ + + assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." + + # FP8 padding calculate + padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits] + + if torch.is_grad_enabled(): + fn = _Fp8Padding.apply + args = [] + else: + fn = _Fp8Padding.forward + args = [None] + + args += ( + inp, + m_splits, + padded_m_splits, + torch.is_grad_enabled(), + ) + out = fn(*args) + + return out, padded_m_splits diff --git a/transformer_engine/pytorch/module/fp8_unpadding.py b/transformer_engine/pytorch/module/fp8_unpadding.py new file mode 100644 index 0000000000..6e08f849ef --- /dev/null +++ b/transformer_engine/pytorch/module/fp8_unpadding.py @@ -0,0 +1,119 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""FP8 Padding API""" + +from typing import List + +import torch + +from ..cpp_extensions import ( + multi_padding_fused, +) +from ..jit import no_torch_dynamo + + +__all__ = ["Fp8Unpadding"] + + +class _Fp8Unpadding(torch.autograd.Function): + """functional FP8 unpadding""" + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + m_splits: List[int], + padded_m_splits: List[int], + is_grad_enabled: bool, + ) -> torch.Tensor: + inputmats = torch.split(inp.view(-1, inp.shape[-1]), padded_m_splits) + out_ret = torch.cat( + [grad_output_mat[: m_splits[i]] for i, grad_output_mat in enumerate(inputmats)], dim=0 + ) + + if is_grad_enabled: + ctx.m_splits = m_splits + ctx.padded_m_splits = padded_m_splits + ctx.requires_dgrad = inp.requires_grad + + return out_ret + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + grad_input = None + if ctx.requires_dgrad: + grad_output = grad_output.contiguous() + + in_features = grad_output.shape[-1] + + # Allocate cast and transpose output tensor + total_row = sum(ctx.padded_m_splits) + grad_input = torch.empty( + [total_row, in_features], dtype=grad_output.dtype, device=grad_output.device + ) + # FP8 pad input for forward, FP8 input transpose for backward wgrad + multi_padding_fused( + grad_output.view(-1, in_features), ctx.m_splits, ctx.padded_m_splits, grad_input + ) + + return (grad_input, None, None, None) + + +class Fp8Unpadding(torch.nn.Module): + """ + Apply the unpadding for Grouped GEMM input. + + Parameters + ---------- + num_gemms: int + number of GEMMs to be performed simutaneously. + """ + + def __init__( + self, + num_gemms, + ) -> None: + super().__init__() + + self.num_gemms = num_gemms + + @no_torch_dynamo() + def forward( + self, + inp: torch.Tensor, + m_splits: List[int], + ) -> torch.Tensor: + """ + Apply the unpadding to the input. + + Parameters + ---------- + inp : torch.Tensor + Input tensor. + m_splits : List[int] + List of integers representing the split of the input tensor. + """ + + assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." + + # FP8 padding calculate + padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits] + + if torch.is_grad_enabled(): + fn = _Fp8Unpadding.apply + args = [] + else: + fn = _Fp8Unpadding.forward + args = [None] + + args += ( + inp, + m_splits, + padded_m_splits, + torch.is_grad_enabled(), + ) + out = fn(*args) + + return out From 206c1d9220ed70bd5d4959194934e9cb6740e0fd Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 5 Sep 2024 11:21:33 -0700 Subject: [PATCH 27/50] [C] Suppress 128-D warning from cudnn-frontend (#1158) suppress 128D warning from cudnn-frontend Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 647d2c474d..deea5c58ad 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -135,6 +135,7 @@ set_source_files_properties(fused_softmax/scaled_masked_softmax.cu COMPILE_OPTIONS "--use_fast_math") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -diag-suppress 128") # Number of parallel build jobs if(ENV{MAX_JOBS}) From bdea56fc023014eaf52a171047b641b3bfdded70 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 5 Sep 2024 11:40:00 -0700 Subject: [PATCH 28/50] Revert "[C] Suppress 128-D warning from cudnn-frontend" (#1161) Revert "[C] Suppress 128-D warning from cudnn-frontend (#1158)" This reverts commit 206c1d9220ed70bd5d4959194934e9cb6740e0fd. Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index deea5c58ad..647d2c474d 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -135,7 +135,6 @@ set_source_files_properties(fused_softmax/scaled_masked_softmax.cu COMPILE_OPTIONS "--use_fast_math") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -diag-suppress 128") # Number of parallel build jobs if(ENV{MAX_JOBS}) From 047a50722780e7b647f9107783e210021190edc3 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Mon, 9 Sep 2024 22:30:48 +0800 Subject: [PATCH 29/50] [PyTorch] Propagate fp8 scale-inverse modification to `GroupedLinear` (#1128) * propagate scale_inv modification to GroupedLinear Signed-off-by: Xin Yao * optimization for separate scale_inv of weights and single output Signed-off-by: Xin Yao * let grouped gemm support different input combinations Signed-off-by: Xin Yao * fix type Signed-off-by: Xin Yao * add contiguous check Signed-off-by: Xin Yao * use len() instead of isinstance Signed-off-by: Xin Yao * fix ut Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Co-authored-by: Kirthi Shankar Sivamani --- tests/pytorch/test_numerics.py | 31 +++-- .../pytorch/cpp_extensions/gemm.py | 129 ++++++++++++------ .../pytorch/cpp_extensions/transpose.py | 3 +- transformer_engine/pytorch/csrc/extensions.h | 10 ++ .../pytorch/csrc/extensions/gemm.cu | 61 +++++++++ transformer_engine/pytorch/csrc/ts_fp8_op.cpp | 36 +++++ .../pytorch/module/grouped_linear.py | 43 ++++-- 7 files changed, 249 insertions(+), 64 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 723f68369b..ad34b4996f 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1266,12 +1266,15 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False ) inp_hidden_states.retain_grad() - m = config.seq_len // 16 - dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() - dist.append(dist[-1]) # Manually add a zero - m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) - m_splits = m_splits * 16 - assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms + if num_gemms > 1: + m = config.seq_len // 16 + dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() + dist.append(dist[-1]) # Manually add a zero + m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) + m_splits = m_splits * 16 + assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms + else: + m_splits = torch.tensor([config.seq_len]) with fp8_autocast(enabled=fp8): if isinstance(block, GroupedLinear): @@ -1353,7 +1356,7 @@ def test_grouped_linear_accuracy( @pytest.mark.parametrize("parallel_mode", ["column", "row"]) def test_grouped_linear_accuracy_parallel_mode(parallel_mode): - """Split the tests to reduce CI time""" + """Split the tests to save CI time""" test_grouped_linear_accuracy( dtype=torch.float32, num_gemms=6, @@ -1365,6 +1368,18 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode): ) +def test_grouped_linear_accuracy_single_gemm(): + """Split the tests to save CI time""" + test_grouped_linear_accuracy( + dtype=torch.float32, + num_gemms=1, + bs=2, + model=list(model_configs.keys())[0], + fp8=True, + fp8_model_params=True, + ) + + def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False): def _pad_tensor_for_fp8(hidden_states, tokens_per_expert): @@ -2034,7 +2049,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate): fp8_grouped_gemm( A_fp8, - scale_inv, + [scale_inv], 0, # A_offset tex.DType.kFloat8E4M3, B_fp8, diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 8502f70491..fd1eb4a810 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -11,7 +11,12 @@ from ..utils import assert_dim_for_fp8_exec -__all__ = ["gemm", "fp8_gemm", "grouped_gemm", "fp8_grouped_gemm"] +__all__ = [ + "gemm", + "fp8_gemm", + "grouped_gemm", + "fp8_grouped_gemm", +] @functools.lru_cache(maxsize=None) @@ -313,7 +318,7 @@ def grouped_gemm( layout: str = "TN", bias: Optional[List[torch.Tensor]] = None, use_bias: bool = False, -) -> Tuple[Union[List[torch.Tensor], None], ...]: +) -> Tuple[List[torch.Tensor], ...]: """Non FP8 Grouped GEMM.""" assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." @@ -380,7 +385,7 @@ def grouped_gemm( def fp8_grouped_gemm( A: List[torch.Tensor], - A_scale_inv: torch.Tensor, + A_scale_inv: List[torch.Tensor], A_fp8_tensor_offset: int, A_dtype: tex.DType, B: List[torch.Tensor], @@ -390,6 +395,7 @@ def fp8_grouped_gemm( out: List[torch.Tensor], out_dtype: torch.dtype, workspaces: List[torch.Tensor], + m_splits: Optional[List[int]] = None, out_offset: Optional[int] = None, fp8_meta_tensor: tex.FP8TensorMeta = None, gelu: bool = False, @@ -398,16 +404,25 @@ def fp8_grouped_gemm( use_bias: bool = False, use_split_accumulator: bool = False, D_dtype: Optional[tex.DType] = None, -) -> Tuple[Union[List[torch.Tensor], None], ...]: +) -> Tuple[List[torch.Tensor], ...]: """ TN layout Grouped GEMM with fp8 inputs. - This method assumes the scale/scale_inv/amax of A/B/out is contiguous in the meta tensor. - scale: [ ...A_scale... | ...B_scale... | ...out_scale...] - scale_inv: [ ...A_scale_inv... | ...B_scale_inv... | ...out_scale_inv...] - amax: [ ...A_amax... | ...B_amax... | ...out_amax...] + Input requirements: + 1. If len(A_scale_inv) == num_gemms, len(out) must be 1, and m_splits is not None. + This is used for the calculation of output (fwd) and dgrad (bwd). + 2. if len(A_scale_inv) == 1, len(out) must be num_gemms. This is used for the + calculation of wgrad. """ - num_gemms = len(A) + if num_gemms > 1 and len(A_scale_inv) == num_gemms: + assert len(out) == 1 and m_splits is not None + elif num_gemms > 1 and len(A_scale_inv) == 1: + assert len(out) == num_gemms + elif num_gemms == 1: + assert len(A_scale_inv) == 1 and len(out) == 1 + else: + raise ValueError("Invalid input combinations of A_scale_inv and out.") + empty_tensor = _empty_tensor() empty_tensors = [empty_tensor] * num_gemms if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: @@ -420,41 +435,71 @@ def fp8_grouped_gemm( # Use bfloat16 as default bias_dtype bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype - if gelu: - gelu_input = [ - torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format) - for o in out - ] - else: - gelu_input = empty_tensors bias_dtype = TE_DType[bias_dtype] - + gelu_input = empty_tensors out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype - torch.ops.tex_ts.te_grouped_gemm_ts( - A, - A_scale_inv, - A_fp8_tensor_offset, - A_dtype, - True, # transa - B, - B_scale_inv, - B_fp8_tensor_offset, - B_dtype, - False, # transb - out, - 0 if out_offset is None else out_offset, - empty_tensor if out_offset is None else fp8_meta_tensor.scale, - out_dtype, - empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, - bias if use_bias else empty_tensors, - bias_dtype, - gelu_input, # this is pre_gelu_out - False, # grad - workspaces, - workspaces[0].shape[0], - accumulate, - use_split_accumulator, - ) + if len(A_scale_inv) == 1: + if gelu: + gelu_input = [ + torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format) + for o in out + ] + + torch.ops.tex_ts.te_grouped_gemm_ts( + A, + A_scale_inv[0], + A_fp8_tensor_offset, + A_dtype, + True, # transa + B, + B_scale_inv, + B_fp8_tensor_offset, + B_dtype, + False, # transb + out, + 0 if out_offset is None else out_offset, + empty_tensor if out_offset is None else fp8_meta_tensor.scale, + out_dtype, + empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, + bias if use_bias else empty_tensors, + bias_dtype, + gelu_input, # this is pre_gelu_out + False, # grad + workspaces, + workspaces[0].shape[0], + accumulate, + use_split_accumulator, + ) + else: + if gelu: + gelu_input = [torch.empty((m, A[0].size(0)), dtype=bias_dtype) for m in m_splits] + + torch.ops.tex_ts.te_grouped_gemm_single_output_ts( + A, + A_scale_inv, + A_fp8_tensor_offset, + A_dtype, + True, # transa + B, + B_scale_inv, + B_fp8_tensor_offset, + B_dtype, + False, # transb + m_splits, + out[0], + 0 if out_offset is None else out_offset, + empty_tensor if out_offset is None else fp8_meta_tensor.scale, + out_dtype, + empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, + bias if use_bias else empty_tensors, + bias_dtype, + gelu_input, # this is pre_gelu_out + False, # grad + workspaces, + workspaces[0].shape[0], + accumulate, + use_split_accumulator, + ) return out, gelu_input diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index 37a1b59da2..ddc3b67e9e 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -175,6 +175,7 @@ def fp8_multi_cast_transpose_fused( amax_indices: List[int], scale_inv_indices: List[int], otype: tex.DType, + scale_inv: Optional[torch.Tensor] = None, ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: """Cast + Transpose with FP8 output""" @@ -182,7 +183,7 @@ def fp8_multi_cast_transpose_fused( input_list, fp8_meta_tensor.scale, fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + scale_inv if scale_inv is not None else fp8_meta_tensor.scale_inv, scale_indices, amax_indices, scale_inv_indices, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 31103cbe8e..c797208e06 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -165,6 +165,16 @@ void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count); +void te_grouped_gemm_single_output( + std::vector A, std::vector A_scale_inverse, int A_offset, + transformer_engine::DType A_type, bool transa, std::vector B, + at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb, + std::vector m_splits, at::Tensor D, int D_offset, at::Tensor D_scale, + transformer_engine::DType D_type, at::Tensor D_amax, std::vector bias, + transformer_engine::DType bias_type, std::vector pre_gelu_out, bool grad, + std::vector workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, int math_sm_count); + /*************************************************************************************************** * Transpose **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu index 7405914a0e..ba9851e7e8 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cu +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu @@ -151,3 +151,64 @@ void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int te_workspace.data(), accumulate, use_split_accumulator, math_sm_count, at::cuda::getCurrentCUDAStream()); } + +void te_grouped_gemm_single_output( + std::vector A, std::vector A_scale_inverse, int A_offset, + transformer_engine::DType A_type, bool transa, std::vector B, + at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb, + std::vector m_splits, at::Tensor D, int D_offset, at::Tensor D_scale, + transformer_engine::DType D_type, at::Tensor D_amax, std::vector bias, + transformer_engine::DType bias_type, std::vector pre_gelu_out, bool grad, + std::vector workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, int math_sm_count) { + using namespace transformer_engine; + std::vector te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace; + std::vector tensor_wrappers; + auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, + transformer_engine::DType dtype, void* amax_dptr, + void* scale_dptr, void* scale_inv_dptr) -> NVTETensor { + tensor_wrappers.emplace_back( + makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr)); + return tensor_wrappers.back().data(); + }; + NVTE_CHECK(D.is_contiguous(), "D must be contiguous."); + void* d_i_ptr = reinterpret_cast(D.data_ptr()); + for (size_t i = 0; i < A.size(); i++) { + if (m_splits[i] == 0) continue; + NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous."); + NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous."); + te_A.emplace_back(make_tensor( + A[i].data_ptr(), {static_cast(A[i].size(0)), static_cast(A[i].size(1))}, + A_type, nullptr, nullptr, getDataPtr(A_scale_inverse[i], A_offset))); + te_B.emplace_back(make_tensor( + B[i].data_ptr(), {static_cast(B[i].size(0)), static_cast(B[i].size(1))}, + B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i))); + te_D.emplace_back(make_tensor( + d_i_ptr, {static_cast(m_splits[i]), static_cast(A[i].size(0))}, D_type, + getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr)); + te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast(bias[i].size(0))}, + bias_type, nullptr, nullptr, nullptr)); + + const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr + ? std::vector{static_cast(pre_gelu_out[i].size(0))} + : std::vector{static_cast(pre_gelu_out[i].size(0)), + static_cast(pre_gelu_out[i].size(1))}; + te_pre_gelu_out.emplace_back(make_tensor( + pre_gelu_out[i].data_ptr(), gelu_shape, + GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr)); + // Move the D pointer to the next split. + char* char_ptr = reinterpret_cast(d_i_ptr); + char_ptr += m_splits[i] * A[i].size(0) * D.element_size(); + d_i_ptr = reinterpret_cast(char_ptr); + } + for (size_t i = 0; i < workspace.size(); i++) { + te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte, + nullptr, nullptr, nullptr)); + } + + // For now, we only have multi-stream cublas backend. + nvte_multi_stream_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), + te_pre_gelu_out.data(), te_A.size(), transa, transb, grad, + te_workspace.data(), accumulate, use_split_accumulator, + math_sm_count, at::cuda::getCurrentCUDAStream()); +} diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp index 8c480e8343..9f31dba669 100644 --- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp +++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp @@ -305,6 +305,41 @@ std::vector te_grouped_gemm_ts( return D; } +at::Tensor te_grouped_gemm_single_output_ts( + std::vector A, std::vector A_scale_inverse, int64_t A_offset, + int64_t A_type, int64_t transa, std::vector B, at::Tensor B_scale_inverse, + int64_t B_offset, int64_t B_type, int64_t transb, std::vector m_splits, at::Tensor D, + int64_t D_offset, at::Tensor D_scale, int64_t D_type, at::Tensor D_amax, + std::vector bias, int64_t bias_type, std::vector pre_gelu_out, + int64_t grad, std::vector workspace, int64_t workspaceSize, int64_t accumulate, + int64_t use_split_accumulator) { + // cast inputs to types accepted by te_gemm + transformer_engine::DType A_type_arg = reverse_map_dtype(A_type); + bool transa_arg = static_cast(transa); + transformer_engine::DType B_type_arg = reverse_map_dtype(B_type); + bool transb_arg = static_cast(transb); + transformer_engine::DType D_type_arg = reverse_map_dtype(D_type); + transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type); + bool grad_arg = static_cast(grad); + size_t workspaceSize_arg = static_cast(workspaceSize); + bool accumulate_arg = static_cast(accumulate); + bool use_split_accumulator_arg = static_cast(use_split_accumulator); + + // Set an external SM Margin to all the GEMMs. + // This comes in handy when DP is overlapped with GEMMs + + const int device_id = at::cuda::current_device(); + const int sm_count = transformer_engine::cuda::sm_count(device_id); + int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); + + te_grouped_gemm_single_output(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, + B_scale_inverse, B_offset, B_type_arg, transb_arg, m_splits, D, + D_offset, D_scale, D_type_arg, D_amax, bias, bias_type_arg, + pre_gelu_out, grad_arg, workspace, workspaceSize_arg, + accumulate_arg, use_split_accumulator_arg, num_math_sms); + return D; +} + at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, double eps, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor, @@ -371,6 +406,7 @@ TORCH_LIBRARY(tex_ts, m) { m.def("srelu_ts", &srelu_ts); m.def("te_gemm_ts", &te_gemm_ts); m.def("te_grouped_gemm_ts", &te_grouped_gemm_ts); + m.def("te_grouped_gemm_single_output_ts", &te_grouped_gemm_single_output_ts); m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts); m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts); m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts); diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index a91ff5c361..ca100392c7 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -42,6 +42,7 @@ from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ..float8_tensor import Float8Tensor +from ..export import is_in_onnx_export_mode __all__ = ["GroupedLinear"] @@ -102,10 +103,12 @@ def forward( inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats] inputmats = [] inputmats_t = [] + inputmat_scale_inv = None global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT if fp8: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + inputmat_scale_inv = torch.empty([num_gemms], dtype=torch.float32, device=inp.device) if ( not fp8_meta["recipe"].override_linear_precision.wgrad and is_grad_enabled @@ -121,6 +124,7 @@ def forward( indices, # amax_indices indices, # scale_inv_indices fp8_dtype_forward, + scale_inv=inputmat_scale_inv, ) else: # FP8 input for forward @@ -130,9 +134,22 @@ def forward( fp8_meta["scaling_fwd"], _GEMM_INPUT + i, fp8_dtype_forward, + scale_inv=inputmat_scale_inv, ) for i in range(num_gemms) ] + + # Hack for ONNX export + # Note: ONNX models are represented as a graph of tensor + # operations, so the in-place scale-inv update doesn't fit + # very well. We work around this by making it look like + # the scale-inv tensor is initialized with a copy. + # Note: ONNX export expects FP8 scales can be represented + # with constant ops. However, copying into a buffer + # involves an expand op for array broadcasting. We work + # around this by filling the buffer instead. + if is_in_onnx_export_mode(): + inputmat_scale_inv.fill_(inputmat_scale_inv.item()) else: inputmats = inputmats_no_fp8 @@ -153,16 +170,17 @@ def forward( _ = fp8_grouped_gemm( [w._data for w in weights_fp8], - fp8_meta["scaling_fwd"].scale_inv, - _GEMM_WEIGHT, + [w._scale_inv for w in weights_fp8], + 0, # weight offset is 0 for the newly created _scale_inv fp8_dtype_forward, inputmats, - fp8_meta["scaling_fwd"].scale_inv, - _GEMM_INPUT, + inputmat_scale_inv, + 0, fp8_dtype_forward, - torch.split(out, m_splits), + [out], activation_dtype, get_multi_stream_cublas_workspace(), + m_splits=m_splits, bias=biases, use_bias=use_bias, use_split_accumulator=_2X_ACC_FPROP, @@ -230,7 +248,7 @@ def forward( t.activation_offloading = True ctx.save_for_backward( - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, + inputmat_scale_inv, *saved_inputmats, *saved_inputmats_t, *weights, @@ -270,7 +288,7 @@ def forward( def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: with torch.cuda.nvtx.range("_GroupedLinear_backward"): ( - fwd_scale_inverses, + inputmat_scale_inv, *saved_tensors, ) = ctx.saved_tensors inputmats = saved_tensors[: ctx.num_gemms] @@ -342,18 +360,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) fp8_grouped_gemm( [w.transpose_2d() for w in weights_fp8], - torch.cat( - [w._scale_inv for w in weights_fp8] - ), # avoiding torch.cat requires another interface + [w._scale_inv for w in weights_fp8], 0, # weight offset is 0 for the newly created _scale_inv weights_fp8[0]._fp8_dtype, grad_output_c, ctx.fp8_meta["scaling_bwd"].scale_inv, _GRAD_OUTPUT, fp8_dtype_backward, - torch.split(dgrad, ctx.m_splits), + [dgrad], ctx.activation_dtype, get_multi_stream_cublas_workspace(), + m_splits=ctx.m_splits, use_split_accumulator=_2X_ACC_DGRAD, ) else: @@ -396,8 +413,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inp._data if isinstance(inp, Float8Tensor) else inp for inp in inputmats_t ], - fwd_scale_inverses, - _GEMM_INPUT, + [inputmat_scale_inv], + 0, fp8_dtype_forward, grad_output_t, ctx.fp8_meta["scaling_bwd"].scale_inv, From 2a9845e1d93440d3c0f65427985e66208d09eff8 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Mon, 9 Sep 2024 11:34:45 -0700 Subject: [PATCH 30/50] Added Adobe analytics to the documentation (#1162) Signed-off-by: Przemyslaw Tredak --- docs/_templates/layout.html | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/docs/_templates/layout.html b/docs/_templates/layout.html index a68b4531e3..f94e526f57 100644 --- a/docs/_templates/layout.html +++ b/docs/_templates/layout.html @@ -1,4 +1,11 @@ {% extends "!layout.html" %} + + {% block extrahead %} + + + + {% endblock %} + {% block sidebartitle %} {{ super() }} - {%- if nvidia_analytics_id %} - - {%- endif %} + {% endblock %} + + {% block footer %} + + {% endblock %} From 40dda924a52866c3a5e9b56f1907b4a2602f2fac Mon Sep 17 00:00:00 2001 From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> Date: Mon, 9 Sep 2024 13:50:47 -0700 Subject: [PATCH 31/50] Add a context parallelism implementation with QKVO all-to-all (#1160) * clean code for CP function args Signed-off-by: Xiaowei Ren * add a placeholder for Ulysses implementation Signed-off-by: Xiaowei Ren * commit code change to CP+A2A Signed-off-by: Xiaowei Ren * finish the draft fwd implementation of Ulysses Signed-off-by: Xiaowei Ren * add draft bwd implementation of Ulysses Signed-off-by: Xiaowei Ren * make swa work with ulysses Signed-off-by: Xiaowei Ren * commit FP8 code for Ulysses Signed-off-by: Xiaowei Ren * fix qkv type in the bwd of FP8+CP Signed-off-by: Xiaowei Ren * typo fix Signed-off-by: Xiaowei Ren * fix qkv_dtype of FP8+CP Signed-off-by: Xiaowei Ren * code refactoring Signed-off-by: Xiaowei Ren * minor code change Signed-off-by: Xiaowei Ren * config cp correction dtype of FP8+CP Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * code style change Signed-off-by: Xiaowei Ren * save chunk_ids Signed-off-by: Xiaowei Ren * try to make Ulysses A2A async Signed-off-by: Xiaowei Ren * make more a2a async Signed-off-by: Xiaowei Ren * fix a2a_outputs Signed-off-by: Xiaowei Ren * fix chunk_ids generation for A2A Signed-off-by: Xiaowei Ren * avoid code duplication of a2a before attn Signed-off-by: Xiaowei Ren * remove code duplication of a2a after attn Signed-off-by: Xiaowei Ren * add cp_stream in A2A implementation Signed-off-by: Xiaowei Ren * bug fix Signed-off-by: Xiaowei Ren * fix qkv of fp8_fwd + bf16_bwd Signed-off-by: Xiaowei Ren * fix kernel order in cp a2a communication Signed-off-by: Xiaowei Ren * code cleaning for CP a2a Signed-off-by: Xiaowei Ren * fix merging with main Signed-off-by: Xiaowei Ren * fix a2a communication order Signed-off-by: Xiaowei Ren * adjust sequence chunk reordering for a2a Signed-off-by: Xiaowei Ren * add docstring for A2A implementation Signed-off-by: Xiaowei Ren * change an assert info Signed-off-by: Xiaowei Ren * add unit tests of A2A implementation Signed-off-by: Xiaowei Ren * add more A2A unit test Signed-off-by: Xiaowei Ren * fix CP unit tests Signed-off-by: Xiaowei Ren * add more cp unit tests Signed-off-by: Xiaowei Ren * fix window size of no_mask Signed-off-by: Xiaowei Ren * fused attn does not support swa+no_mask Signed-off-by: Xiaowei Ren * change num_gqa_groups to 2 for A2A implementation Signed-off-by: Xiaowei Ren * function and variable renaming Signed-off-by: Xiaowei Ren * code cleaning for CP all-gather implementation Signed-off-by: Xiaowei Ren * some function renaming Signed-off-by: Xiaowei Ren * remove redundant code Signed-off-by: Xiaowei Ren * commit code change for kv all-gather implementation Signed-off-by: Xiaowei Ren * fix all-gather implementation Signed-off-by: Xiaowei Ren * add a window size check Signed-off-by: Xiaowei Ren * code cleaning Signed-off-by: Xiaowei Ren * add unit test of all_gather+no_mask Signed-off-by: Xiaowei Ren * fix all-gather cp implementation Signed-off-by: Xiaowei Ren * code cleaning Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * code format fix Signed-off-by: Xiaowei Ren * code format fix Signed-off-by: Xiaowei Ren * fix FP8 with A2A implementation Signed-off-by: Xiaowei Ren * add paper references to CP implementations with all-gather and all-to-all Signed-off-by: Xiaowei Ren * change pdf to abs Signed-off-by: Xiaowei Ren * elaborate cp_comm_type Signed-off-by: Xiaowei Ren * fix CP docstring Signed-off-by: Xiaowei Ren --------- Signed-off-by: Xiaowei Ren Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../fused_attn/test_fused_attn_with_cp.py | 110 +- transformer_engine/pytorch/attention.py | 1020 +++++++++++++---- transformer_engine/pytorch/transformer.py | 8 +- 3 files changed, 849 insertions(+), 289 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 82875e2791..d6358d1062 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -22,10 +22,16 @@ "cp_1_2": ModelConfig( 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) ), # MHA - "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA - "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA + "cp_1_3": ModelConfig( + 2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512) + ), # MHA + "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA + "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA "cp_2_2": ModelConfig( - 2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + 2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + ), # GQA + "cp_2_3": ModelConfig( + 2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512) ), # GQA } @@ -45,31 +51,32 @@ def get_bash_arguments(**kwargs): @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"]) +@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"]) def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): config = model_configs_flash_attn[model] + if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1): + pytest.skip("CP implementation with KV P2P does not support sliding window yet!") if cp_comm_type == "all_gather" and qkv_format == "thd": - pytest.skip( - f"CP implementation with KV all-gather does not support {qkv_format} format yet!" - ) - if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type: - pytest.skip( - f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask" - " type yet!" - ) + pytest.skip("CP implementation with KV all-gather does not support THD format yet!") if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": + pytest.skip("CP implementation with KV all-gather does not support bias yet!") + if cp_comm_type == "a2a" and qkv_format == "thd": + pytest.skip("CP implementation with QKVO A2A does not support THD format yet!") + if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias": + pytest.skip("CP implementation with QKVO A2A does not support bias yet!") + if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): pytest.skip( - f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias" - " type yet!" - ) - if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1): - pytest.skip( - f"CP implementation with KV P2P does not support window size {config.window_size} yet!" + f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" + f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" ) subprocess.run( get_bash_arguments( - dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention" + dtype=dtype, + model=model, + qkv_format=qkv_format, + kernel_backend="FlashAttention", + cp_comm_type=cp_comm_type, ), check=True, ) @@ -81,10 +88,16 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA "cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA - "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA - "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA - "cp_2_2": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA - "cp_2_3": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA + "cp_1_4": ModelConfig( + 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + ), # MHA + "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA + "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA + "cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA + "cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA + "cp_2_4": ModelConfig( + 2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + ), # GQA } @@ -93,37 +106,27 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"]) @pytest.mark.parametrize("model", model_configs_fused_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"]) +@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"]) def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): if qkv_format == "thd" and get_device_compute_capability() < (9, 0): - pytest.skip("THD format is only supported on sm90+.") + pytest.skip("THD format is only supported on sm90+!") if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0): - pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0") + pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!") config = model_configs_fused_attn[model] if qkv_format == "thd" and config.num_heads != config.num_gqa_groups: - pytest.skip(f"{qkv_format} format does not support QGA/MQA yet!") + pytest.skip("THD format does not support QGA/MQA yet!") if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": - pytest.skip(f"{qkv_format} format does not support {config.attn_bias_type} bias type yet!") - if cp_comm_type == "all_gather" and qkv_format == "thd": - pytest.skip( - f"CP implementation with KV all-gather does not support {qkv_format} format yet!" - ) - if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type: - pytest.skip( - f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask" - " type yet!" - ) - if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": - pytest.skip( - f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias" - " type yet!" - ) - if config.window_size != (-1, 0) and config.window_size != (-1, -1): + pytest.skip("THD format does not support post_scale_bias yet!") + if qkv_format == "thd" and cp_comm_type == "all_gather": + pytest.skip("CP implementation with KV all-gather does not support THD format yet!") + if qkv_format == "thd" and cp_comm_type == "a2a": + pytest.skip("CP implementation with QKVO A2A does not support THD format yet!") + if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a": pytest.skip( - "Fused attention does not support sliding window attention + context parallelism yet!" + "Sliding window attention only can be supported with the implementation of QKVO A2A!" ) - if cp_comm_type == "all_gather" and dtype == "fp8": + if dtype == "fp8" and cp_comm_type == "all_gather": pytest.skip( "CP implementation with KV all-gather does not support FP8 + context parallelism yet!" ) @@ -131,10 +134,25 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): pytest.skip("FP8 attention cannot work with THD format yet!") if dtype == "fp8" and config.attn_bias_type != "no_bias": pytest.skip("FP8 attention cannot work with bias yet!") + if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1): + pytest.skip("FP8 attention cannot work with sliding window yet!") + if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": + pytest.skip("CP implementation with KV all-gather does not support bias yet!") + if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias": + pytest.skip("CP implementation with QKVO A2A does not support bias yet!") + if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): + pytest.skip( + f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" + f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" + ) subprocess.run( get_bash_arguments( - dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention" + dtype=dtype, + model=model, + qkv_format=qkv_format, + kernel_backend="FusedAttention", + cp_comm_type=cp_comm_type, ), check=True, ) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 91c14899ec..f8ba46b2ea 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -614,12 +614,6 @@ def get_attention_backend( "with causal mask, no dropout, and qkv_format = bshd/sbhd" ) use_fused_attention = False - elif context_parallel: - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention " - "with context parallelism" - ) - use_fused_attention = False elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [ "no_mask", "padding", @@ -1429,9 +1423,6 @@ def forward( cu_seqlens_q_padded, cu_seqlens_kv_padded, dropout_p, - cp_group, - cp_global_ranks, - cp_stream, softmax_scale, qkv_format, attn_mask_type, @@ -1441,6 +1432,9 @@ def forward( use_fused_attention, fp8, fp8_meta, + cp_group, + cp_global_ranks, + cp_stream, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -2946,10 +2940,10 @@ def backward(ctx, dout): None, None, None, + attn_dbias, None, None, None, - attn_dbias, None, None, None, @@ -2958,30 +2952,56 @@ def backward(ctx, dout): @torch.compile -def get_seq_chunk_ids_to_all_gathered_kv( - local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size_left, device +def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous): + """ + Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing. + To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks + before or after CP communications (e.g., all-gather, all-to-all). This function is to compute + sequence chunk ids for reordering. + """ + chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device) + if to_contiguous: + for rank in range(cp_size): + chunk_ids[rank] = 2 * rank + chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1 + else: + for rank in range(cp_size): + chunk_ids[2 * rank] = rank + chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1 + return chunk_ids + + +def get_kv_seq_info_after_all_gather( + local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size, causal ): - """Compute sequence chunk ids to the all-gathered KV.""" - seq_end_idx = (local_chunk_id + 1) * max_seqlen_kv - seq_start_idx = max(0, seq_end_idx - max_seqlen_q - window_size_left) - seqlen = seq_end_idx - seq_start_idx - num_chunks = (seqlen + max_seqlen_kv - 1) // max_seqlen_kv - chunk_ids = torch.arange( - local_chunk_id - num_chunks + 1, - local_chunk_id + 1, - dtype=torch.int32, - device=device, - ) - chunk_ids_to_all_gathered_kv = torch.where( - chunk_ids < cp_size, 2 * chunk_ids, 2 * (2 * cp_size - chunk_ids) - 1 - ) - return chunk_ids_to_all_gathered_kv + """Compute KV sequence index range and update window size after all-gather.""" + local_chunk_end_idx = (local_chunk_id + 1) * max_seqlen_kv + full_seq_end_idx = max_seqlen_kv * cp_size * 2 + + if window_size is None: + window_size = (-1, 0) if causal else (-1, -1) + + if window_size[1] == -1: + seq_end_idx = full_seq_end_idx + window_size_right = -1 + else: + seq_end_idx = min(full_seq_end_idx, local_chunk_end_idx + window_size[1]) + window_size_right = local_chunk_end_idx + window_size[1] - seq_end_idx + + if window_size[0] == -1: + seq_start_idx = 0 + window_size_left = -1 + else: + seq_start_idx = max(0, local_chunk_end_idx - max_seqlen_q - window_size[0]) + window_size_left = window_size[0] + seq_end_idx - local_chunk_end_idx + + return (seq_start_idx, seq_end_idx), (window_size_left, window_size_right) class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): """ - Attention implementation with context parallelism. - KV all-gather between CP ranks is exposed. + Attention implementation with context parallelism. KV all-gather between CP ranks is exposed. + Refer section 3.3.2 of `The Llama 3 Herd of Models `_. """ @staticmethod @@ -2992,14 +3012,10 @@ def forward( k, v, cu_seqlens_q, - cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, cu_seqlens_q_padded, - cu_seqlens_kv_padded, dropout_p, - cp_group, - cp_stream, softmax_scale, qkv_format, attn_mask_type, @@ -3008,6 +3024,8 @@ def forward( deterministic, use_fused_attention, window_size, + cp_group, + cp_stream, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -3017,10 +3035,9 @@ def forward( causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type - assert causal and not padding, f"{attn_mask_type} mask type is not supported!" + assert not padding, f"{attn_mask_type} mask type is not supported!" if use_fused_attention and causal and "bottom_right" not in attn_mask_type: attn_mask_type = attn_mask_type + "_bottom_right" - assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" assert ( @@ -3029,6 +3046,8 @@ def forward( fa_optional_forward_kwargs = {} if _flash_attn_2_4_plus: fa_optional_forward_kwargs["alibi_slopes"] = None + if _flash_attn_2_5_7_plus: + fa_optional_forward_kwargs["block_table"] = None assert qkv_format != "thd", f"{qkv_format} format is not supported!" qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format @@ -3041,31 +3060,35 @@ def forward( max_seqlen_q = max_seqlen_q // (2 * cp_size) max_seqlen_kv = max_seqlen_kv // (2 * cp_size) cu_seqlens_q = cu_seqlens_q // (2 * cp_size) - cu_seqlens_kv = cu_seqlens_kv // (2 * cp_size) cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size) - cu_seqlens_kv_padded = cu_seqlens_kv_padded // (2 * cp_size) - - if causal: - if qkv_format == "bshd": - # [b, s, np, hn] -> [b, 2, s//2, np, hn] - q = q.view(q.shape[0], 2, q.shape[1] // 2, *q.shape[2:]) - # [b, s, np, hn] -> [s, b, np, hn] - k, v = [x.transpose(0, 1).contiguous() for x in [k, v]] - elif qkv_format == "sbhd": - # [s, b, np, hn] -> [2, s//2, b, np, hn] - q = q.view(2, q.shape[0] // 2, *q.shape[1:]) - # create two streams to resolve wave quantization issue of Flash Attn in each step - flash_attn_streams = [torch.cuda.current_stream(), cp_stream] + # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn] + q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) + # [b, s, np, hn] or [s, b, np, hn] -> [s, b, np, hn] + k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] + # [s, b, np, hn] -> [cp, s, b, np, hn] k_ag, _ = gather_along_first_dim(k, cp_group) v_ag, _ = gather_along_first_dim(v, cp_group) - cp_stream.wait_stream(torch.cuda.current_stream()) + + # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True) + k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) + v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] + k_ag = k_ag.view(-1, *k.shape[1:]) + v_ag = v_ag.view(-1, *v.shape[1:]) + cp_stream.wait_stream(torch.cuda.current_stream()) + + # create two streams to resolve wave quantization issue of Flash Attn in each step + flash_attn_streams = [torch.cuda.current_stream(), cp_stream] local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] - chunk_ids_to_kv_ag_per_step = [None, None] + kv_seq_range_per_step = [None, None] + window_size_per_step = [None, None] + cu_seqlens_kv_per_step = [None, None] out_per_step = [None, None] softmax_lse_per_step = [None, None] rng_states = [None, None] @@ -3074,53 +3097,36 @@ def forward( for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - chunk_ids_to_kv_ag = get_seq_chunk_ids_to_all_gathered_kv( - local_seq_chunk_ids[i], - cp_size, - max_seqlen_q, - max_seqlen_kv, - ( - max_seqlen_kv * cp_size * 2 - if (window_size is None or window_size[0] == -1) - else window_size[0] - ), - k.device, - ) - chunk_ids_to_kv_ag_per_step[i] = chunk_ids_to_kv_ag - num_kv_chunks = chunk_ids_to_kv_ag.numel() - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_ = q[:, i].contiguous() - # [num_kv_chunks, sq//2, b, np, hn] -> [b, num_kv_chunks*sq//2, np, hn] - k_ = ( - torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag) - .movedim(2, 0) - .contiguous() - .view(k.shape[1], -1, *k.shape[-2:]) - ) - v_ = ( - torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag) - .movedim(2, 0) - .contiguous() - .view(v.shape[1], -1, *v.shape[-2:]) - ) - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - q_ = q[i].contiguous() - # [num_kv_chunks, sq//2, b, np, hn] -> [num_kv_chunks*sq//2, b, np, hn] - k_ = torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag).view( - -1, *k.shape[-3:] - ) - v_ = torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag).view( - -1, *v.shape[-3:] + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_ = q.select(seq_dim, i).contiguous() + kv_seq_range_per_step[i], window_size_per_step[i] = ( + get_kv_seq_info_after_all_gather( + local_seq_chunk_ids[i], + cp_size, + max_seqlen_q, + max_seqlen_kv, + window_size, + causal, ) + ) + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i][0], + kv_seq_range_per_step[i][1], + ) + max_seqlen_kv_ = seq_end_idx - seq_start_idx + cu_seqlens_kv_per_step[i] = _get_full_cu_seqlens( + k.shape[1], max_seqlen_kv_, k.device + ) + k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn] + k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] if use_fused_attention: out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd( is_training, max_seqlen_q, - max_seqlen_kv * num_kv_chunks, + max_seqlen_kv_, cu_seqlens_q, - cu_seqlens_kv * num_kv_chunks, + cu_seqlens_kv_per_step[i], q_, k_, v_, @@ -3133,8 +3139,8 @@ def forward( attn_bias_type=attn_bias_type, attn_bias=attn_bias, cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded * num_kv_chunks, - window_size=window_size, + cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], + window_size=window_size_per_step[i], ) else: q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]] @@ -3144,14 +3150,14 @@ def forward( k_, v_, cu_seqlens_q, - cu_seqlens_kv * num_kv_chunks, + cu_seqlens_kv_per_step[i], max_seqlen_q, - max_seqlen_kv * num_kv_chunks, + max_seqlen_kv_, dropout_p, softmax_scale, - causal=True, + causal=causal, return_softmax=False, - window_size=window_size, + window_size=window_size_per_step[i], **fa_optional_forward_kwargs, ) ) @@ -3159,9 +3165,9 @@ def forward( if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): if qkv_format == "bshd": - out[:, i - 1].copy_(out_per_step[i - 1].view_as(out[:, i - 1])) + out[:, i - 1].copy_(out_per_step[i - 1].view(out[:, i - 1].shape)) elif qkv_format == "sbhd": - out[i - 1].copy_(out_per_step[i - 1].view_as(out[i - 1])) + out[i - 1].copy_(out_per_step[i - 1].view(out[i - 1].shape)) torch.cuda.current_stream().wait_stream(cp_stream) @@ -3178,26 +3184,24 @@ def forward( k, v, cu_seqlens_q, - cu_seqlens_kv, cu_seqlens_q_padded, - cu_seqlens_kv_padded, - *chunk_ids_to_kv_ag_per_step, + *cu_seqlens_kv_per_step, *out_per_step, *softmax_lse_per_step, *rng_states, ) + ctx.kv_seq_range_per_step = kv_seq_range_per_step + ctx.window_size_per_step = window_size_per_step ctx.cp_group = cp_group ctx.cp_stream = cp_stream ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_kv = max_seqlen_kv ctx.softmax_scale = softmax_scale ctx.qkv_format = qkv_format - ctx.attn_mask_type = attn_mask_type ctx.attn_bias_type = attn_bias_type + ctx.attn_mask_type = attn_mask_type ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention - ctx.window_size = window_size return out @staticmethod @@ -3205,21 +3209,20 @@ def backward(ctx, dout): cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) - (q, k, v, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ( - ctx.saved_tensors[:7] - ) - chunk_ids_to_kv_ag_per_step = ctx.saved_tensors[7:9] - out_per_step = ctx.saved_tensors[9:11] - softmax_lse_per_step = ctx.saved_tensors[11:13] - rng_states = ctx.saved_tensors[13:15] + (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = ctx.saved_tensors[:5] + cu_seqlens_kv_per_step = ctx.saved_tensors[5:7] + out_per_step = ctx.saved_tensors[7:9] + softmax_lse_per_step = ctx.saved_tensors[9:11] + rng_states = ctx.saved_tensors[11:13] + kv_seq_range_per_step = ctx.kv_seq_range_per_step + window_size_per_step = ctx.window_size_per_step + seq_dim = ctx.qkv_format.index("s") qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format - dout = dout.view_as(q) + dout = dout.view(q.shape) dq = torch.empty_like(q) - dk = torch.zeros( - (2 * cp_size, k.shape[0] // 2, *k.shape[1:]), dtype=k.dtype, device=k.device - ) + dk = torch.zeros((k.shape[0] * cp_size, *k.shape[1:]), dtype=k.dtype, device=k.device) dv = torch.zeros_like(dk) dq_per_step = [None, None] dk_per_step = [None, None] @@ -3230,11 +3233,20 @@ def backward(ctx, dout): # synchronize dkv update across steps dkv_update_done = torch.cuda.Event() + # [s, b, np, hn] -> [cp, s, b, np, hn] k_ag, _ = gather_along_first_dim(k, ctx.cp_group) v_ag, _ = gather_along_first_dim(v, ctx.cp_group) - ctx.cp_stream.wait_stream(torch.cuda.current_stream()) + + # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True) + k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) + v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] + k_ag = k_ag.view(-1, *k.shape[1:]) + v_ag = v_ag.view(-1, *v.shape[1:]) + ctx.cp_stream.wait_stream(torch.cuda.current_stream()) local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] @@ -3247,66 +3259,46 @@ def backward(ctx, dout): for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - chunk_ids_to_kv_ag = chunk_ids_to_kv_ag_per_step[i] - num_kv_chunks = chunk_ids_to_kv_ag.numel() + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_ = q.select(seq_dim, i).contiguous() + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i][0], + kv_seq_range_per_step[i][1], + ) + max_seqlen_kv = seq_end_idx - seq_start_idx + k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # [cp*s, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn] + k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] out_ = out_per_step[i] - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_ = q[:, i].contiguous() - # [num_kv_chunks, sq//2, b, np, hn] -> [b, num_kv_chunks*sq//2, np, hn] - k_ = ( - torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag) - .movedim(2, 0) - .contiguous() - .view(k.shape[1], -1, *k.shape[-2:]) - ) - v_ = ( - torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag) - .movedim(2, 0) - .contiguous() - .view(v.shape[1], -1, *v.shape[-2:]) - ) - dout_ = dout[:, i].contiguous().view_as(out_) - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - q_ = q[i].contiguous() - # [num_kv_chunks, sq//2, b, np, hn] -> [num_kv_chunks*sq//2, b, np, hn] - k_ = torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag).view( - -1, *k.shape[-3:] - ) - v_ = torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag).view( - -1, *v.shape[-3:] - ) - dout_ = dout[i].contiguous().view_as(out_) + dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) if ctx.use_fused_attention: - dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ - torch.empty_like(x) for x in [q_, k_, v_] - ] aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]] dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd( ctx.max_seqlen_q, - ctx.max_seqlen_kv * num_kv_chunks, + max_seqlen_kv, cu_seqlens_q, - cu_seqlens_kv * num_kv_chunks, + cu_seqlens_kv_per_step[i], q_, k_, v_, out_, dout_, TE_DType[q.dtype], - TE_DType[k.dtype], + TE_DType[dout.dtype], aux_ctx_tensors, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded * num_kv_chunks, + cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, qkv_layout=qkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, - window_size=ctx.window_size, + window_size=window_size_per_step[i], + deterministic=ctx.deterministic, ) else: + batch_size = k_.shape[0] q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]] dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ torch.empty_like(x) for x in [q_, k_, v_] @@ -3322,65 +3314,601 @@ def backward(ctx, dout): dk_per_step[i], dv_per_step[i], cu_seqlens_q, - cu_seqlens_kv * num_kv_chunks, + cu_seqlens_kv_per_step[i], ctx.max_seqlen_q, - ctx.max_seqlen_kv * num_kv_chunks, + max_seqlen_kv, ctx.dropout_p, ctx.softmax_scale, - True, - window_size=ctx.window_size, + "causal" in ctx.attn_mask_type, + window_size=window_size_per_step[i], rng_state=rng_states[i], **fa_optional_backward_kwargs, ) + # [b*sq//2, np, hn] -> [b, sq//2, np, hn] + dq_per_step[i] = dq_per_step[i].view(dq[:, i].shape) + # [b*s_range, np, hn] -> [b, s_range, np, hn] + dk_per_step[i], dv_per_step[i] = [ + x.view(batch_size, -1, *x.shape[-2:]) + for x in [dk_per_step[i], dv_per_step[i]] + ] if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): - chunk_ids_to_kv_ag = chunk_ids_to_kv_ag_per_step[i - 1] - num_kv_chunks = chunk_ids_to_kv_ag.numel() if ctx.qkv_format == "bshd": - dq[:, i - 1].copy_(dq_per_step[i - 1].view_as(dq[:, i - 1])) - dk_per_step[i - 1] = ( - dk_per_step[i - 1] - .view(k.shape[1], num_kv_chunks, -1, *k.shape[-2:]) - .movedim(0, 2) - .contiguous() - ) - dv_per_step[i - 1] = ( - dv_per_step[i - 1] - .view(v.shape[1], num_kv_chunks, -1, *v.shape[-2:]) - .movedim(0, 2) - .contiguous() - ) + dq[:, i - 1].copy_(dq_per_step[i - 1]) elif ctx.qkv_format == "sbhd": - dq[i - 1].copy_(dq_per_step[i - 1].view_as(dq[i - 1])) - dk_per_step[i - 1] = dk_per_step[i - 1].view( - num_kv_chunks, -1, *k.shape[-3:] - ) - dv_per_step[i - 1] = dv_per_step[i - 1].view( - num_kv_chunks, -1, *v.shape[-3:] - ) - + dq[i - 1].copy_(dq_per_step[i - 1]) + # [b, s_range, np, hn] or [s_range, b, np, hn] -> [s_range, b, np, hn] + dk_per_step[i - 1], dv_per_step[i - 1] = [ + x.movedim(seq_dim, 0).contiguous() + for x in [dk_per_step[i - 1], dv_per_step[i - 1]] + ] # wait until dkv update of last step is done if i > 1: flash_attn_streams[i - 1].wait_event(dkv_update_done) - dk.index_add_(0, chunk_ids_to_kv_ag, dk_per_step[i - 1]) - dv.index_add_(0, chunk_ids_to_kv_ag, dv_per_step[i - 1]) + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i - 1][0], + kv_seq_range_per_step[i - 1][1], + ) + dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1]) + dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1]) if i < len(local_seq_chunk_ids): flash_attn_streams[i - 1].record_event(dkv_update_done) torch.cuda.current_stream().wait_stream(ctx.cp_stream) + # [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn] + dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) + dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, dk.device, False) + dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) + dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] dk = dk.view(-1, *dk.shape[-3:]) dv = dv.view(-1, *dv.shape[-3:]) dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) + dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :]) + dk = dk.movedim(0, seq_dim).contiguous() + dv = dv.movedim(0, seq_dim).contiguous() + + return ( + None, + dq, + dk, + dv, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +@torch.compile +def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn): + """Reorder sequence chunk for A2A communication.""" + if before_attn: + # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn] or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn] + x = x.movedim(0, seq_dim).contiguous() + # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :]) + # reorder the sequence chunks + x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a) + else: + # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.movedim(seq_dim, 0).contiguous() + # reorder the sequence chunks + x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a) + # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn] + x = x.view(cp_size, 2, *x.shape[1:]) + return x + + +def flash_attn_a2a_communicate( + a2a_inputs: Union[torch.Tensor, List[torch.Tensor]], + chunk_ids_for_a2a: torch.Tensor, + seq_dim: int, + cp_size: int, + cp_group: dist_group_type, + cp_stream: torch.cuda.Stream, + before_attn: bool, +) -> Union[torch.Tensor, List[torch.Tensor]]: + """A2A communication for context parallelism.""" + a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs + a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs) + if before_attn: + for i in range(len(a2a_inputs) + 2): + if 0 < i < len(a2a_inputs) + 1: + a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1]) + a2a_reqs[i - 1] = torch.distributed.all_to_all_single( + a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True + ) + if i > 1: + with torch.cuda.stream(cp_stream): + a2a_reqs[i - 2].wait() + x = a2a_outputs[i - 2] + # reorder the sequence chunks + x = reorder_seq_chunks_for_a2a( + x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn + ) + # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] + a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) + if i < len(a2a_inputs): + x = a2a_inputs[i] + # [b, s, np, hn] -> [b, s, cp, np//cp, hn] or [s, b, np, hn] -> [s, b, cp, np//cp, hn] + x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1]) + # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] + a2a_inputs[i] = x.movedim(-3, 0).contiguous() + else: + for i in range(len(a2a_inputs) + 2): + if 0 < i < len(a2a_inputs) + 1: + a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1]) + a2a_reqs[i - 1] = torch.distributed.all_to_all_single( + a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True + ) + if i < len(a2a_inputs): + x = a2a_inputs[i] + # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) + # reorder the sequence chunks + a2a_inputs[i] = reorder_seq_chunks_for_a2a( + x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn + ) + if i > 1: + with torch.cuda.stream(cp_stream): + a2a_reqs[i - 2].wait() + x = a2a_outputs[i - 2] + # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] + x = x.movedim(0, -3).movedim(0, seq_dim).contiguous() + # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] + a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) + torch.cuda.current_stream().wait_stream(cp_stream) + return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs + + +class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): + """ + Attention implementation with context parallelism. Like Ulysses, applying A2A to QKVO. + Refer the paper `DeepSpeed Ulysses `_. + """ + + @staticmethod + def forward( + ctx, + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + dropout_p, + softmax_scale, + qkv_format, + attn_mask_type, + attn_bias_type, + attn_bias, + deterministic, + use_fused_attention, + window_size, + fp8, + fp8_meta, + cp_group, + cp_stream, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + cp_size = get_distributed_world_size(cp_group) + + causal = "causal" in attn_mask_type + padding = "padding" in attn_mask_type + assert not padding, f"{attn_mask_type} mask type is not supported!" + assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" + assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" + assert ( + window_size == (-1, 0) + or window_size == (-1, -1) + or use_fused_attention + or _flash_attn_2_3_plus + ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" + fa_optional_forward_kwargs = {} + if _flash_attn_2_3_plus: + fa_optional_forward_kwargs["window_size"] = window_size + if _flash_attn_2_4_plus: + fa_optional_forward_kwargs["alibi_slopes"] = None + if _flash_attn_2_5_7_plus: + fa_optional_forward_kwargs["block_table"] = None + + assert ( + q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0 + ), "The number of attention heads needs to be divisible by CP size!" + + assert qkv_format != "thd", f"{qkv_format} format is not supported!" + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + + batch_dim = qkv_format.index("b") + seq_dim = qkv_format.index("s") + assert ( + q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 + ), "Sequence length per GPU needs to be divisible by 2!" + + if fp8: + if use_fused_attention: + fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + fused_attn_qkv_dtype = fp8_dtype_forward + fused_attn_backend = FusedAttnBackend["FP8"] + if fp8_meta["recipe"].fp8_mha: + assert ( + isinstance(q, Float8Tensor) + and isinstance(k, Float8Tensor) + and isinstance(v, Float8Tensor) + ), "q/k/v must be Float8Tensors for FP8 MHA!" + fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv + q_fp8, k_fp8, v_fp8 = q, k, v + q, k, v = q_fp8._data, k_fp8._data, v_fp8._data + elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q_f16, k_f16, v_f16 = q, k, v + q, k, v = [ + cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + for x in [q_f16, k_f16, v_f16] + ] + fp8_meta_kwargs = {} + fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv + fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV + fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv + fp8_meta_kwargs["d_scale_s_offset"] = META_S + fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale + fp8_meta_kwargs["q_scale_s_offset"] = META_S + fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale + fp8_meta_kwargs["q_scale_o_offset"] = META_O + fp8_meta_kwargs["amax_s"] = fp8_meta["scaling_fwd"].amax_history + fp8_meta_kwargs["amax_s_offset"] = META_S + fp8_meta_kwargs["amax_o"] = fp8_meta["scaling_fwd"].amax_history + fp8_meta_kwargs["amax_o_offset"] = META_O + else: + assert False, "FP8 is only supported with Fused Attention!" + else: + if use_fused_attention: + fp8_meta_kwargs = {} + fused_attn_qkv_dtype = TE_DType[q.dtype] + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, True) + q, k, v = flash_attn_a2a_communicate( + [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True + ) + + if fp8 and not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q_f16, k_f16, v_f16 = q, k, v + q, k, v = [ + cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + for x in [q_f16, k_f16, v_f16] + ] + + batch_size = q.shape[batch_dim] + if use_fused_attention: + out, aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + q, + k, + v, + fused_attn_qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + window_size=window_size, + **fp8_meta_kwargs, + ) + else: + # [b, cp*s, np//cp, hn] -> [b*cp*s, np//cp, hn] + q, k, v = [x.view(-1, *x.shape[-2:]) for x in [q, k, v]] + ( + _, + _, + _, + _, + out, + softmax_lse, + _, + rng_state, + ) = _flash_attn_forward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + dropout_p, + softmax_scale, + causal=causal, + return_softmax=False, + **fa_optional_forward_kwargs, + ) + aux_ctx_tensors = [softmax_lse, rng_state] + # [b*cp*s, np//cp, hn] -> [b, cp*s, np//cp, hn] + out = out.view(batch_size, -1, *out.shape[-2:]) + + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, False) + out = flash_attn_a2a_communicate( + out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False + ) + + if use_fused_attention: + if qkv_format == "bshd": + # [b*s, np, hn] -> [b, s, np, hn] + out = out.view(batch_size, -1, *out.shape[-2:]) + elif qkv_format == "sbhd": + # [s*b, np, hn] -> [s, b, np, hn] + out = out.view(-1, batch_size, *out.shape[-2:]) + + if fp8: + if fp8_meta["recipe"].fp8_mha: + out_fp8 = Float8Tensor( + data=out, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_O, + fp8_dtype=fp8_dtype_forward, + dtype=q_fp8.dtype, + ) + out = out_fp8._data + out_ret = out_fp8 + else: + out_f16 = cast_from_fp8( + out, + fp8_meta["scaling_fwd"], + META_O, + fp8_dtype_forward, + TE_DType[q_f16.dtype], + ) + out_ret = out_f16 + else: + out_ret = out + + if fp8: + if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q_save, k_save, v_save, out_save = q, k, v, out + elif fp8_meta["recipe"].fp8_mha: + q_fp8, k_fp8, v_fp8 = [ + Float8Tensor( + data=x, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_QKV, + fp8_dtype=fp8_dtype_forward, + dtype=out_fp8.dtype, + ) + for x in [q, k, v] + ] + q_save, k_save, v_save, out_save = q_fp8, k_fp8, v_fp8, out_fp8 + else: + q_save, k_save, v_save, out_save = q_f16, k_f16, v_f16, out_f16 + else: + q_save, k_save, v_save, out_save = q, k, v, out + + if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone() + fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone() + else: + fp8_fwd_scales, fp8_fwd_scale_invs = None, None + + ctx.save_for_backward( + q_save, + k_save, + v_save, + out_save, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + fp8_fwd_scales, + fp8_fwd_scale_invs, + *aux_ctx_tensors, + ) + ctx.batch_size = batch_size + ctx.cp_group = cp_group + ctx.cp_stream = cp_stream + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_kv = max_seqlen_kv + ctx.softmax_scale = softmax_scale + ctx.qkv_format = qkv_format + ctx.attn_mask_type = attn_mask_type + ctx.attn_bias_type = attn_bias_type + ctx.deterministic = deterministic + ctx.window_size = window_size + ctx.use_fused_attention = use_fused_attention + ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.fp8_meta = fp8_meta + return out_ret + + @staticmethod + def backward(ctx, dout): + cp_size = get_distributed_world_size(ctx.cp_group) + + q, k, v, out = ctx.saved_tensors[:4] + cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = ctx.saved_tensors[ + 4:8 + ] + fp8_fwd_scales, fp8_fwd_scale_invs = ctx.saved_tensors[8:10] + aux_ctx_tensors = ctx.saved_tensors[10:] + + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + causal = "causal" in ctx.attn_mask_type + seq_dim = ctx.qkv_format.index("s") + + if ctx.fp8: + if ctx.use_fused_attention: + fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) + fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) + fused_attn_qkv_dtype = fp8_dtype_forward + fused_attn_dqkv_dtype = fp8_dtype_backward + fused_attn_backend = FusedAttnBackend["FP8"] + if ctx.fp8_meta["recipe"].fp8_mha: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" + ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv + dout_fp8 = dout + dout = dout_fp8._data + else: + dout_f16 = dout + dout = cast_to_fp8( + dout_f16, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward + ) + fp8_meta_kwargs = {} + fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV] + fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S] + fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O] + fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] + fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP] + fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S] + fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP] + fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV] + fp8_meta_kwargs["amax_dp"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP] + fp8_meta_kwargs["amax_dqkv"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][ + META_DQKV + ] + else: + assert False, "FP8 is only supported with Fused Attention!" + else: + if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" + q, k, v, out, dout = [x.from_float8(x.dtype) for x in [q, k, v, out, dout]] + if ctx.use_fused_attention: + fp8_meta_kwargs = {} + fused_attn_qkv_dtype = TE_DType[q.dtype] + fused_attn_dqkv_dtype = TE_DType[dout.dtype] + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + + if not ctx.use_fused_attention: + out = out.view(ctx.batch_size, -1, *out.shape[-2:]) + dout = dout.view(*out.shape) + + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, True) + out, dout = flash_attn_a2a_communicate( + [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True + ) + + fa_optional_backward_kwargs = {} + if _flash_attn_2_3_plus: + fa_optional_backward_kwargs["window_size"] = ctx.window_size + if _flash_attn_2_4_plus: + fa_optional_backward_kwargs["alibi_slopes"] = None + if _flash_attn_2_4_1_plus: + fa_optional_backward_kwargs["deterministic"] = ctx.deterministic + + if ctx.use_fused_attention: + dq, dk, dv, _ = fused_attn_bwd( + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + q, + k, + v, + out, + dout, + fused_attn_qkv_dtype, + fused_attn_dqkv_dtype, + aux_ctx_tensors, + fused_attn_backend, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + attn_scale=ctx.softmax_scale, + dropout=ctx.dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=ctx.attn_mask_type, + attn_bias_type=ctx.attn_bias_type, + window_size=ctx.window_size, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, + ) + else: + softmax_lse, rng_state = aux_ctx_tensors + out, dout = [x.view(-1, *x.shape[-2:]) for x in [out, dout]] + dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_kv, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + ctx.dropout_p, + ctx.softmax_scale, + causal, + rng_state=rng_state, + **fa_optional_backward_kwargs, + ) + dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] + + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, False) + dq, dk, dv = flash_attn_a2a_communicate( + [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False + ) + if ctx.qkv_format == "bshd": - dq = dq.view(dq.shape[0], -1, *dq.shape[-2:]) - dk = dk.transpose(0, 1).contiguous() - dv = dv.transpose(0, 1).contiguous() + dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] elif ctx.qkv_format == "sbhd": - dq = dq.view(-1, *dq.shape[-3:]) + dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] + + if ctx.fp8: + if ctx.fp8_meta["recipe"].fp8_mha: + dq, dk, dv = [ + Float8Tensor( + data=x, + fp8_meta=ctx.fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=META_DQKV, + fp8_dtype=fp8_dtype_backward, + dtype=dout_fp8.dtype, + ) + for x in [dq, dk, dv] + ] + else: + dq, dk, dv = [ + cast_from_fp8( + x, + ctx.fp8_meta["scaling_bwd"], + META_DQKV, + fp8_dtype_backward, + TE_DType[dout_f16.dtype], + ) + for x in [dq, dk, dv] + ] return ( None, @@ -3404,6 +3932,9 @@ def backward(ctx, dout): None, None, None, + None, + None, + None, ) @@ -3465,57 +3996,44 @@ def attn_forward_func_with_cp( sliding_window_attn = ( window_size is not None and window_size != (-1, 0) and window_size != (-1, -1) ) + assert ( + not sliding_window_attn + or cp_comm_type == "a2a" + or (cp_comm_type == "all_gather" and not use_fused_attention) + ), "The context parallel running configs cannot support sliding window attetnion!" - if sliding_window_attn or cp_comm_type == "all_gather": - out = AttnFuncWithCPAndKVAllGather.apply( - is_training, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - dropout_p, - cp_group, - cp_stream, - softmax_scale, - qkv_format, - attn_mask_type, - attn_bias_type, - attn_bias, - deterministic, - use_fused_attention, - window_size, - ) - elif cp_comm_type == "p2p": - out = AttnFuncWithCPAndKVP2P.apply( - is_training, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - dropout_p, - cp_group, - cp_global_ranks, - cp_stream, - softmax_scale, - qkv_format, - attn_mask_type, - attn_bias_type, - attn_bias, - deterministic, - use_fused_attention, - fp8, - fp8_meta, - ) + args = [ + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + dropout_p, + softmax_scale, + qkv_format, + attn_mask_type, + attn_bias_type, + attn_bias, + deterministic, + use_fused_attention, + ] + + if cp_comm_type == "p2p": + args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream] + out = AttnFuncWithCPAndKVP2P.apply(*args) + elif cp_comm_type == "all_gather": + args.pop(5) + args.pop(8) + args += [window_size, cp_group, cp_stream] + out = AttnFuncWithCPAndKVAllGather.apply(*args) + elif cp_comm_type == "a2a": + args += [window_size, fp8, fp8_meta, cp_group, cp_stream] + out = AttnFuncWithCPAndQKVOA2A.apply(*args) else: raise ValueError(f"Unsupported communication type: {cp_comm_type}!") @@ -6416,7 +6934,13 @@ class DotProductAttention(TransformerEngineBaseModule): can overlap two flash attention kernels. cp_comm_type : str inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather". + Can be "p2p" or "all_gather" or "a2a". + "p2p": Exchange KV chunks with P2P communications in ring topology. + P2P is async and can be overlapped with attention compute. + "all_gather": All-gather to get full sequence of KV before attention. + The all-gather is not async, and cannot be overlapped. + "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP + group, and gather to get full sequence of QKV. """ def __init__( @@ -6608,7 +7132,13 @@ def set_context_parallel_group( cuda stream for context parallel execution. cp_comm_type : str inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather". + Can be "p2p" or "all_gather" or "a2a". + "p2p": Exchange KV chunks with P2P communications in ring topology. + P2P is async and can be overlapped with attention compute. + "all_gather": All-gather to get full sequence of KV before attention. + The all-gather is not async, and cannot be overlapped. + "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP + group, and gather to get full sequence of QKV. """ self.cp_group = cp_group self.cp_global_ranks = cp_global_ranks @@ -7633,7 +8163,13 @@ def set_context_parallel_group( cuda stream for context parallel execution. cp_comm_type : str inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather". + Can be "p2p" or "all_gather" or "a2a". + "p2p": Exchange KV chunks with P2P communications in ring topology. + P2P is async and can be overlapped with attention compute. + "all_gather": All-gather to get full sequence of KV before attention. + The all-gather is not async, and cannot be overlapped. + "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP + group, and gather to get full sequence of QKV. """ # Deep iterate but skip self to avoid infinite recursion. for index, child in enumerate(self.modules()): diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index bd6e27594d..958c7019ba 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -503,7 +503,13 @@ def set_context_parallel_group( cuda stream for context parallel execution. cp_comm_type : str inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather". + Can be "p2p" or "all_gather" or "a2a". + "p2p": Exchange KV chunks with P2P communications in ring topology. + P2P is async and can be overlapped with attention compute. + "all_gather": All-gather to get full sequence of KV before attention. + The all-gather is not async, and cannot be overlapped. + "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP + group, and gather to get full sequence of QKV. """ # Deep iterate but skip self to avoid infinite recursion. for index, child in enumerate(self.modules()): From 2d57db8bcc5cf5562e726e978c875877c478a139 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 11 Sep 2024 06:12:03 -0700 Subject: [PATCH 32/50] [PyTorch] Proxy class for low-precision tensor (#1127) * Add base class for tensor proxies Signed-off-by: Tim Moon * Move tensor detaching logic to tensor proxy base class Signed-off-by: Tim Moon * Use Python wrappers to PyTorch extensions Signed-off-by: Tim Moon * Include transpose caching logic in proxy encode function Signed-off-by: Tim Moon * Debug dimension mismatch with amax history Signed-off-by: Tim Moon * Move dequantize logic to proxy_decode func Signed-off-by: Tim Moon * Rename to "QuantizedTensor" Signed-off-by: Tim Moon * Rename "proxy_detach" to "detach" Signed-off-by: Tim Moon * Include transpose cache in detach and clone funcs Signed-off-by: Tim Moon * Fix linter warnings Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update FP8 workspaces with QuantizedTensor functions Signed-off-by: Tim Moon * Move logic for FP8 transpose cache in FP8 workspaces to base class Signed-off-by: Tim Moon * Remove cast-transpose logic from linear op Signed-off-by: Tim Moon * Remove unnecessary args for Float8Tensor when using FP8 attr dict Signed-off-by: Tim Moon * Remove __torch_function__ to QuantizedTensor Signed-off-by: Tim Moon * Fix linter warnings Signed-off-by: Tim Moon * Update tests/pytorch/test_float8tensor.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Debug FP8 transpose test Signed-off-by: Tim Moon * Debug cast functions Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- .github/workflows/lint.yml | 2 +- tests/pytorch/test_float8tensor.py | 9 +- tests/pytorch/test_fusible_ops.py | 5 +- .../pytorch/cpp_extensions/_common.py | 6 +- .../pytorch/cpp_extensions/cast.py | 5 +- transformer_engine/pytorch/float8_tensor.py | 1001 +---------------- transformer_engine/pytorch/module/base.py | 96 +- .../pytorch/module/grouped_linear.py | 21 +- .../pytorch/module/layernorm_linear.py | 34 +- .../pytorch/module/layernorm_mlp.py | 19 +- transformer_engine/pytorch/module/linear.py | 30 +- transformer_engine/pytorch/ops/_common.py | 62 +- .../pytorch/ops/basic/all_reduce.py | 11 +- .../pytorch/ops/basic/basic_linear.py | 74 +- .../pytorch/ops/basic/reduce_scatter.py | 17 +- transformer_engine/pytorch/tensor/__init__.py | 8 + .../pytorch/tensor/float8_tensor.py | 972 ++++++++++++++++ .../pytorch/tensor/quantized_tensor.py | 172 +++ transformer_engine/pytorch/utils.py | 50 + 19 files changed, 1352 insertions(+), 1242 deletions(-) create mode 100644 transformer_engine/pytorch/tensor/__init__.py create mode 100644 transformer_engine/pytorch/tensor/float8_tensor.py create mode 100644 transformer_engine/pytorch/tensor/quantized_tensor.py diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index a4fdbdfdfd..d2bd865a8f 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -31,7 +31,7 @@ jobs: run: | sudo apt-get update sudo apt-get install pip -y - pip install torch + pip install torch numpy export PYTHON_ONLY=1 export TE_PATH=. bash ./qa/L0_pytorch_lint/test.sh diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 0ea0319771..fd204f58c4 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -293,7 +293,7 @@ def test_transpose( with pytest.raises(AssertionError): torch.testing.assert_close(x_fp8_t, x, **tols) - # Caching test. + # Caching test assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching." x_fp8 += 0.5 x = x_fp8.from_float8() @@ -302,14 +302,13 @@ def test_transpose( torch.testing.assert_close(x_fp8_t, x_t, **tols) assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." - # Inplace update test. + # Inplace update test x_fp8 += 0.5 - assert x_fp8._transpose_invalid, "Transpose cache not invalidated properly." + assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." x = x_fp8.from_float8() - x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(fill_cache=True)) + x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8._transpose) x_t = x.transpose(0, 1) torch.testing.assert_close(x_fp8_t, x_t, **tols) - assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." def test_serialization( self, diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 3523e1cda5..e97dfe1efd 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -88,10 +88,7 @@ def make_reference_and_test_tensors( ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) test = ref.to(device=test_device, dtype=test_dtype) if test_is_fp8: - test = Float8Tensor.to_float8(test) - test._transpose = test._data.reshape(-1, test.size(-1)).transpose(0, 1) - test._transpose = test._transpose.contiguous() - test._transpose_invalid = False + test = Float8Tensor.to_float8(test, with_transpose_cache=True) elif test.data_ptr() == ref.data_ptr(): test = test.clone() ref.copy_(test) diff --git a/transformer_engine/pytorch/cpp_extensions/_common.py b/transformer_engine/pytorch/cpp_extensions/_common.py index 6ab7d95138..b9d7288dfa 100644 --- a/transformer_engine/pytorch/cpp_extensions/_common.py +++ b/transformer_engine/pytorch/cpp_extensions/_common.py @@ -68,13 +68,13 @@ def canonicalize_fp8_scales( # Force offsets to be the same if needed if not allow_multiple_offsets and not scale_offset == amax_offset == scale_inv_offset: if scale_offset != 0: - scale = scale[scale_offset] + scale = scale[scale_offset:] scale_offset = 0 if amax_offset != 0: - amax = amax[0][amax_offset] + amax = amax[:, amax_offset:] amax_offset = 0 if scale_inv_offset != 0: - scale_inv = scale_inv[scale_inv_offset] + scale_inv = scale_inv[scale_inv_offset:] scale_inv_offset = 0 # Pack tensors and offsets into dicts diff --git a/transformer_engine/pytorch/cpp_extensions/cast.py b/transformer_engine/pytorch/cpp_extensions/cast.py index 0c78a65a6c..cd3c01c785 100644 --- a/transformer_engine/pytorch/cpp_extensions/cast.py +++ b/transformer_engine/pytorch/cpp_extensions/cast.py @@ -8,7 +8,7 @@ import torch import transformer_engine_torch as tex -from ._common import canonicalize_fp8_scales, empty_tensor +from ._common import canonicalize_fp8_scales __all__ = ["cast_to_fp8", "cast_from_fp8"] @@ -81,8 +81,7 @@ def cast_from_fp8( # Construct empty tensors if needed if scale_inv is None: - scale_inv = empty_tensor() - scale_inv_offset = 0 + raise ValueError("Did not provide either `scale_inv` or `fp8_meta_tensor`") # Launch FP8 cast kernel return torch.ops.tex_ts.cast_from_fp8_ts( diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index d531979868..c3d8709925 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -3,1004 +3,7 @@ # See LICENSE for license information. """Tensor class with FP8 data""" -from __future__ import annotations -from typing import Any, Dict, Optional, Tuple, Union -import warnings -import torch -from torch.utils._pytree import tree_map -import transformer_engine_torch as tex +from .tensor import Float8Tensor -from .constants import TE_DType -from .cpp_extensions import fp8_cast_transpose_fused -from .fp8 import FP8GlobalStateManager - -aten = torch.ops.aten -c10d = torch.ops.c10d -updated_fp8_params = {} - - -def _make_fp8_attr_property_funcs(name: str) -> Any: - """Make accessors for an FP8 attribute - - We store FP8 attributes in a dictionary so we can share them - between tensors with the same data, e.g. detached tensors. For - convenience, we also expose them as property attributes. This - function creates the accessors for property attributes. - - Parameters - ---------- - name: str - Key in dictionary of FP8 attributes - - """ - - def get_func(self) -> Any: - return self._fp8_attrs[name] - - def set_func(self, value: Any) -> None: - self._fp8_attrs[name] = value - - def del_func(self) -> None: - del self._fp8_attrs[name] - - return dict(fget=get_func, fset=set_func, fdel=del_func) - - -class _FromFloat8Func(torch.autograd.Function): - """Cast from FP8 to other dtype""" - - @staticmethod - def forward( - _ctx: torch.autograd.function.FunctionCtx, # unused - tensor: Float8Tensor, - dtype: Optional[torch.dtype] = None, - ) -> torch.Tensor: - if dtype is None: - dtype = tensor.dtype - data = tensor._data.contiguous().view(1, -1).detach() - out = tex.cast_from_fp8( - data, - tensor._scale_inv, - tensor._fp8_dtype, - TE_DType[dtype], - ) - out = out.view(tensor.size()) - return out - - @staticmethod - def backward( - _ctx: torch.autograd.function.FunctionCtx, # unused - grad: torch.Tensor, - ) -> Tuple[Optional[torch.Tensor], ...]: - # Assume that we want gradients in full precision - return grad, None - - -def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None: - """Amax scale and update when there is at least 1 trainable FP8 parameter.""" - param_id = id(param._data) - - if param_id not in FP8GlobalStateManager.fp8_param_to_autocast: - return - - autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id] - - if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params: - return - - if autocast_key in updated_fp8_params: - updated_fp8_params[autocast_key].add(param_id) - else: - updated_fp8_params[autocast_key] = {param_id} - - current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key] - # All FP8 trainable parameters have been updated. - if updated_fp8_params[autocast_key] == current_fp8_params_set: - FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True, fp8_weights=True) - del updated_fp8_params[autocast_key] - - -class _ToFloat8Func(torch.autograd.Function): - """Cast to FP8 from other dtype""" - - @staticmethod - def forward( - _ctx: torch.autograd.function.FunctionCtx, # unused - tensor: torch.Tensor, - fp8_meta: Optional[Dict[str, Any]] = None, - fp8_meta_forward: bool = True, - fp8_meta_index: Optional[int] = None, - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, - ) -> Float8Tensor: - - # Extract data from FP8 meta tensors if provided - if fp8_meta is not None: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=fp8_meta_forward, - ) - if fp8_meta_index is None: - raise ValueError( - "To initialize Float8Tensor with FP8 meta tensors, " - "the FP8 meta tensor index must also be provided" - ) - if scale is None: - scale = fp8_meta[fp8_meta_key].scale[fp8_meta_index] - if amax is None: - amax = fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] - - # Check input tensor - tensor = tensor.contiguous().cuda().detach() - if tensor.dtype not in (torch.float32, torch.bfloat16, torch.float16): - tensor = tensor.float() - - # Check scale - if not isinstance(scale, torch.Tensor): - if scale is None: - scale = 1 - scale = torch.full( - [1], - scale, - dtype=torch.float32, - device=tensor.device, - ) - if scale.numel() != 1: - raise ValueError("Attempted to initialize Float8Tensor with invalid scale tensor") - scale = scale.to(device=tensor.device, dtype=torch.float32) - - # Check scale-inverse - if scale_inv is None: - scale_inv = torch.empty([1], dtype=torch.float32, device=tensor.device) - else: - scale_inv = scale_inv.to(device=tensor.device, dtype=torch.float32) - - # Check amax - if amax is None: - amax = torch.empty_like(scale) - if not (amax.numel() == 1 and amax.is_cuda and amax.dtype == torch.float32): - raise ValueError("Attempted to initialize Float8Tensor with invalid amax tensor") - - # Cast data to FP8 - data = tex.cast_to_fp8( - tensor.view(1, -1), - scale, - amax, - scale_inv, - fp8_dtype, - ) - data = data.view(tensor.size()) - - # Construct FP8 tensor - return Float8Tensor( - data=data, - fp8_meta=fp8_meta, - fp8_meta_forward=fp8_meta_forward, - fp8_meta_index=fp8_meta_index, - fp8_dtype=fp8_dtype, - fp8_scale_inv=scale_inv, - dtype=tensor.dtype, - ) - - @staticmethod - def backward( - _ctx: torch.autograd.function.FunctionCtx, # unused - grad: torch.Tensor, - ) -> Tuple[Optional[torch.Tensor], ...]: - # Assume that we want gradients in full precision - return grad, None, None, None, None, None, None, None - - -class _IdentityFunc(torch.autograd.Function): - """Identity function - - If constructor keyword-arguments are provided, then construct a - new Float8Tensor using the provided tensor's attributes. - - """ - - @staticmethod - def forward( - ctx, - tensor: Float8Tensor, - init_kwargs: Optional[Dict[str, Any]] = None, - ) -> torch.Tensor: - - # Return input tensor if constructor kwargs are not provided - ctx.input_dtype = tensor.dtype - if init_kwargs is None: - return tensor - - # Construct new tensor if constructor kwargs are provided - default_kwargs = dict( - data=tensor._data, - fp8_meta=tensor._fp8_meta, - fp8_meta_forward=tensor._fp8_meta_forward, - fp8_meta_index=tensor._fp8_meta_index, - fp8_dtype=tensor._fp8_dtype, - fp8_scale_inv=tensor._scale_inv, - dtype=tensor.dtype, - ) - for key, val in default_kwargs.items(): - if key not in init_kwargs: - init_kwargs[key] = val - return Float8Tensor(**init_kwargs) - - @staticmethod - def backward(ctx, grad): - return grad.to(ctx.input_dtype), None - - -class _ViewFunc(torch.autograd.Function): - """View function - - View the Float8Tensor using the provided shape. - - """ - - @staticmethod - def forward( - ctx, - tensor: torch.Tensor, - shape: Tuple[int] = None, - ) -> torch.Tensor: - - # Return input tensor if shape is not provided - ctx.shape = tensor.shape - if shape is None: - return tensor - - # Construct new tensor if shape is provided - if isinstance(tensor, Float8Tensor): - return Float8Tensor.make_like( - tensor, - data=tensor._data.view(*shape), - ) - return tensor.view(*shape) - - @staticmethod - def backward( - ctx, - grad: torch.Tensor, - ) -> Tuple[Union[torch.Tensor, None], ...]: - - if isinstance(grad, Float8Tensor): - dgrad = Float8Tensor.make_like( - grad, - data=grad._data.view(ctx.shape), - ) - return dgrad, None - return grad.view(ctx.shape), None - - -class _ReshapeFunc(torch.autograd.Function): - """Reshape function - - Reshape the Float8Tensor using the provided shape. - - """ - - @staticmethod - def forward( - ctx, - tensor: torch.Tensor, - shape: Tuple[int] = None, - ) -> torch.Tensor: - - # Return input tensor if shape is not provided - ctx.shape = tensor.shape - if shape is None: - return tensor - - # Construct new tensor if shape is provided - if isinstance(tensor, Float8Tensor): - return Float8Tensor.make_like( - tensor, - data=tensor._data.reshape(*shape), - ) - return tensor.reshape(*shape) - - @staticmethod - def backward( - ctx, - grad: torch.Tensor, - ) -> Tuple[Union[torch.Tensor, None], ...]: - - if isinstance(grad, Float8Tensor): - dgrad = Float8Tensor.make_like( - grad, - data=grad._data.reshape(ctx.shape), - ) - return dgrad, None - return grad.reshape(ctx.shape), None - - -class Float8Tensor(torch.Tensor): - """Experimental tensor class with FP8 data - - The tensor presents as having a standard, higher-precision dtype, - but the data itself is (scaled) FP8. For most tensor operations, - the data will be cast to the nominal dtype before performing the - operation. - - Parameters - ---------- - data: torch.Tensor - Raw FP8 data in a uint8 tensor - fp8_attrs: dict, optional - FP8 metadata, primarily managed by Float8Tensor. If - provided, all other FP8 configuration is ignored. - fp8_meta: dict, optional - FP8 metadata object, primarily managed by TE modules. - fp8_meta_forward: bool, default = `True` - Whether to access the FP8 metadata for the - forward pass. Ignored if fp8_meta is not - provided. - fp8_meta_index: int, optional - Index to access in FP8 meta tensors. Required if - fp8_meta is provided and otherwise ignored. - fp8_dtype: transformer_engine_torch.DType, tex.DType.kFloat8E4M3 - FP8 format. - fp8_scale_inv: torch.Tensor - Reciprocal of the scaling factor applied when - casting to FP8, i.e. the scaling factor that must - be applied when casting from FP8 to higher - precision. Can be inferred from fp8_meta if - provided. - dtype: torch.dtype, default = torch.float32 - Nominal tensor datatype. - - """ - - def __new__( - cls, - *, - data: torch.Tensor, - fp8_attrs: Optional[Dict[str, Any]] = None, - fp8_meta: Optional[Dict[str, Any]] = None, - fp8_meta_forward: bool = True, - fp8_meta_index: Optional[int] = None, - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, - fp8_scale_inv: Optional[torch.Tensor] = None, - dtype: torch.dtype = torch.float32, - ): - - # Check that data buffer is valid - if data.element_size() != 1: - raise ValueError( - f"Float8Tensor requires data buffer with 8-bit dtype (got dtype={data.dtype})" - ) - if data.requires_grad: - raise ValueError("Float8Tensor requires non-differentiable data buffer") - if not data.is_cuda: - data = data.cuda() - - # Initialize tensor object - self = torch.Tensor._make_wrapper_subclass( - cls, - data.size(), - strides=data.stride(), - storage_offset=data.storage_offset(), - dtype=dtype, - layout=data.layout, - requires_grad=data.requires_grad, - device=data.device, - ) - self._data: torch.Tensor = data - - # Initialize dict of class attributes - # Note: We store FP8 attributes in a dictionary so we can - # share them between tensors with the same data, e.g. detached - # tensors. - self._fp8_attrs: dict = {} - if fp8_attrs is not None: - self._fp8_attrs = fp8_attrs - return self - - # FP8 meta tensors - if fp8_meta is not None and fp8_meta_index is None: - raise ValueError( - "To initialize Float8Tensor with FP8 meta tensors, " - "the FP8 meta tensor index must also be provided" - ) - self._fp8_meta: Optional[Dict[str, Any]] = fp8_meta - self._fp8_meta_forward: bool = fp8_meta_forward - self._fp8_meta_index: Optional[int] = fp8_meta_index - - # FP8 dtype - assert fp8_dtype in ( - tex.DType.kFloat8E4M3, - tex.DType.kFloat8E5M2, - ), f"Unsupported fp8_dtype {fp8_dtype}." - self._fp8_dtype: tex.DType = fp8_dtype - - # Transposed version of `_data`. - self._transpose: Optional[Float8Tensor] = None - self._transpose_invalid: bool = True - - # FP8 scale-inverse - if fp8_scale_inv is None and self._fp8_meta is not None: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=self._fp8_meta_forward, - ) - fp8_scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] - fp8_scale_inv = fp8_scale_inv.detach().view(1).clone() - if fp8_scale_inv is None: - raise ValueError( - "Attempted to initialize Float8Tensor without specifying scale-inverse" - ) - if not isinstance(fp8_scale_inv, torch.Tensor): - fp8_scale_inv = torch.full( - [1], - fp8_scale_inv, - dtype=torch.float32, - device=self._data.device, - ) - if fp8_scale_inv.numel() != 1: - raise ValueError( - "Attempted to initialize Float8Tensor with invalid scale-inverse tensor" - ) - if fp8_scale_inv.dim() != 1: - fp8_scale_inv = fp8_scale_inv.reshape(1) - if fp8_scale_inv.device != self._data.device or fp8_scale_inv.dtype != torch.float32: - fp8_scale_inv = fp8_scale_inv.to( - device=self._data.device, - dtype=torch.float32, - ) - self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv - - return self - - @classmethod - def make_like( - cls, - tensor: Float8Tensor, - *, - data: torch.Tensor, - fp8_attrs: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> Float8Tensor: - """Use attributes of a Float8Tensor to create another Float8Tensor - - See constructor for list of keyword arguments. - - """ - default_kwargs = dict( - fp8_meta=tensor._fp8_meta, - fp8_meta_forward=tensor._fp8_meta_forward, - fp8_meta_index=tensor._fp8_meta_index, - fp8_dtype=tensor._fp8_dtype, - fp8_scale_inv=tensor._scale_inv, - dtype=tensor.dtype, - ) - for key, val in default_kwargs.items(): - if key not in kwargs: - kwargs[key] = val - return Float8Tensor(data=data, fp8_attrs=fp8_attrs, **kwargs) - - def __repr__(self): - return ( - "Float8Tensor(" - f"fp8_dtype={self._fp8_dtype}, " - f"scale_inv={self._scale_inv.item()}, " - f"data={self.from_float8(dtype=self.dtype)}" - ")" - ) - - def from_float8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: - """ - Construct plain PyTorch tensor from Float8Tensor - - By default the resulting tensor's dtype is the - Float8Tensor's nominal dtype. - """ - return _FromFloat8Func.apply(self, dtype) - - @classmethod - def to_float8( - cls, - tensor: torch.Tensor, - *, - fp8_meta: Optional[Dict[str, Any]] = None, - fp8_meta_forward: bool = True, - fp8_meta_index: Optional[int] = None, - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, - ): - """Construct Float8Tensor from plain PyTorch tensor""" - return _ToFloat8Func.apply( - tensor, - fp8_meta, - fp8_meta_forward, - fp8_meta_index, - fp8_dtype, - scale, - amax, - scale_inv, - ) - - def float(self) -> torch.Tensor: - return self.from_float8(dtype=torch.float32) - - def bfloat16(self) -> torch.Tensor: - return self.from_float8(dtype=torch.bfloat16) - - def half(self) -> torch.Tensor: - return self.from_float8(dtype=torch.float16) - - def cpu(self) -> torch.Tensor: - return self.from_float8().cpu() - - def clone(self) -> Float8Tensor: - return _IdentityFunc.apply(self, {"data": self._data.detach().clone()}) - - def view(self, *shape: Tuple[int]) -> Float8Tensor: - return _ViewFunc.apply(self, shape) - - def reshape(self, *shape: Tuple[int]) -> Float8Tensor: - return _ReshapeFunc.apply(self, shape) - - def expand_as(self, other: torch.Tensor): - if other is self: - # Note: expand_as is hackily used to create dummy autograd nodes - # and access the backward graph (see - # https://github.com/pytorch/pytorch/blob/238fb660851268f44ff88127887041fea352fe48/torch/nn/parallel/distributed.py#L1026). - # We equally hackily add a dummy function to handle this - # case. - return _IdentityFunc.apply(self) - return super().expand_as(other) - - def contiguous( - self, - *, - memory_format: torch.memory_format = torch.contiguous_format, - ) -> Float8Tensor: - """Returns tensor with data in provided memory format - - Returns `self` if data is already in correct memory format. - - """ - if self._data.is_contiguous(memory_format=memory_format): - return self - return _IdentityFunc.apply( - self, - {"data": self._data.detach().contiguous(memory_format=memory_format)}, - ) - - def transpose_2d( - self, - *, - force_compute: bool = False, - fill_cache: bool = False, - noop_flag: Optional[torch.Tensor] = None, - cache: Optional[bool] = None, - ) -> torch.Tensor: - """ - 2D transpose with caching support. - - Parameters - ---------- - force_compute: bool, default = `False` - Force computation of transpose. Otherwise use - cached values, if possible. - fill_cache: bool, default = `False` - Cache output tensor for future function calls. - noop_flag: torch.Tensor, optional - float32 flag indicating whether to avoid updating - cached values, if possible. - cache: bool, deprecated - - """ - assert self.dim() == 2, f"{self.dim()}-D transpose not supported." - - # Handle deprecated cache kwarg - if cache is not None: - msg = ( - "cache kwarg for Float8Tensor.transpose_2d is deprecated, " - "please use force_compute and fill_cache instead" - ) - warnings.warn(msg, DeprecationWarning) - if cache: - force_compute = False - fill_cache = True - else: - force_compute = True - fill_cache = False - - # Need to compute transpose if cache is invalid - need_compute = force_compute - if self._transpose is None: - need_compute = True - elif self._transpose_invalid: - need_compute = True - - # Need to apply transpose kernel if noop flag is applied - if noop_flag is not None: - need_compute = True - - # Return cached transpose if possible - if not need_compute: - return self._transpose - - # Allocate output if needed - data = self._data.contiguous().reshape(-1, self.size(-1)) - out = self._transpose - if out is None: - out = torch.empty( - (data.size(1), data.size(0)), - dtype=torch.uint8, - device=data.device, - ) - noop_flag = None - else: - self._transpose_invalid = False - - # Apply transpose kernel - fp8_dtype = self._fp8_dtype - if noop_flag is None: - tex.fp8_transpose_noalloc(data, out, fp8_dtype) - else: - noop_flag = noop_flag.to(dtype=torch.float32, device=data.device) - tex.fp8_transpose_noalloc_noop(data, out, noop_flag, fp8_dtype) - - # Fill cache if needed - if fill_cache: - self._transpose = out - self._transpose_invalid = False - - return out - - @torch.no_grad() - def cast_transpose_( - self, - tensor: torch.Tensor, - noop_flag: Optional[torch.Tensor] = None, - ) -> None: - """Cast from tensor and populate transpose cache - - Only supported for 2D tensors. - - Parameters - ---------- - tensor: torch.Tensor - Tensor to copy from. Must have same dimensions as - destination tensor. - noop_flag: torch.Tensor, optional - float32 flag indicating whether to avoid updating - destination tensor. - - """ - - # Make sure tensor is in expected format - data = self._data - if ( - tensor.device != data.device - or tensor.dtype not in (torch.float32, torch.float16, torch.bfloat16) - or not tensor.is_contiguous() - ): - dtype = tensor.dtype - if dtype not in (torch.float32, torch.float16, torch.bfloat16): - dtype = torch.float32 - tensor = tensor.to( - device=self.device, - dtype=dtype, - memory_format=torch.contiguous_format, - ) - if tensor.size() != data.size() or data.dim() != 2: - raise ValueError( - "Invalid tensor dimensions for FP8 cast-transpose " - f"(src={tuple(tensor.size())}, dst={tuple(data.size())})" - ) - if not data.is_contiguous(): - raise ValueError( - "FP8 cast-transpose is only supported for `Float8Tensor`s with contiguous data" - ) - if self._fp8_meta is None: - raise ValueError( - "FP8 cast-transpose is only supported for `Float8Tensor`s with FP8 metadata " - ) - - # Construct transpose cache if needed - transpose = self._transpose - if transpose is None or not transpose.is_contiguous(): - transpose = torch.empty( - (data.size(1), data.size(0)), - dtype=torch.uint8, - device=data.device, - ) - self._transpose = transpose - noop_flag = None - - # Launch cast-transpose kernel - fp8_meta_index = int(self._fp8_meta_index) - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=self._fp8_meta_forward, - ) - fp8_meta = self._fp8_meta[fp8_meta_key] - fp8_cast_transpose_fused( - tensor, - fp8_meta, - fp8_meta_index, - self._fp8_dtype, - cast_out=data, - transpose_out=transpose, - scale_inv=self._scale_inv, - noop_flag=noop_flag, - ) - self._transpose_invalid = False - - @torch.no_grad() - def reset_fp8_meta_scale_inv(self) -> None: - """Replace FP8 meta tensor scale-inverse with cached value - - The FP8 meta tensor scale_inv entry corresponding to this - tensor is replaced with the scale_inv value used to construct - the tensor. - - """ - assert self._fp8_meta is not None, "FP8 meta tensors not found." - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=self._fp8_meta_forward, - ) - self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index].copy_(self._scale_inv[0]) - - def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: - """Create `Float8Tensor` with given nominal dtype - - The new tensor has the same underlying FP8 data. - - """ - return Float8Tensor.make_like( - self, - data=self._data, - fp8_attrs=self._fp8_attrs, - dtype=dtype, - ) - - def _reset_caches(self) -> None: - """ - Set transpose cache as invalid. - Should be called after any in-place operation. - """ - self._transpose_invalid = True - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): - - # In-place copy op - if func == aten.copy_.default: - - # Check tensors - dst = args[0] - src = args[1] - if not isinstance(dst, torch.Tensor): - raise RuntimeError("Attempted to copy into something that isn't a PyTorch tensor") - if not isinstance(src, torch.Tensor): - raise RuntimeError("Attempted to copy from something that isn't a PyTorch tensor") - - # Special handling based on which tensors are FP8 - dst_is_fp8 = isinstance(dst, Float8Tensor) - src_is_fp8 = isinstance(src, Float8Tensor) - if dst_is_fp8 and src_is_fp8: - - # Directly copy FP8 data if possible - if dst._fp8_dtype == src._fp8_dtype: - dst._data.copy_(src._data) - dst._scale_inv.copy_(src._scale_inv.detach()) - if dst._fp8_meta is not None: - if src._fp8_meta is None: - src_min, src_max = src.from_float8().aminmax() - src_amax = torch.maximum(-src_min, src_max) - else: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=src._fp8_meta_forward, - ) - fp8_meta_index = src._fp8_meta_index - src_amax = src._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=dst._fp8_meta_forward, - ) - fp8_meta_index = dst._fp8_meta_index - dst_amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] - torch.maximum(src_amax, dst_amax, out=dst_amax) - else: - dst.copy_(src.from_float8()) - - elif not dst_is_fp8 and src_is_fp8: - - # Cast source tensor to higher precision - dst.copy_(src.from_float8()) - - elif dst_is_fp8 and not src_is_fp8: - # Make sure input is in expected format - src = src.expand(dst.size()) - src = src.to( - device=dst.device, - memory_format=torch.contiguous_format, - ) - - # Update scaling factor if FP8 meta tensors are available - if dst._fp8_meta is None: - scale = dst._scale_inv.reciprocal() - amax = torch.empty_like(scale) - else: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=dst._fp8_meta_forward, - ) - fp8_meta_index = dst._fp8_meta_index - scale = dst._fp8_meta[fp8_meta_key].scale[fp8_meta_index] - amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] - - # Cast to FP8 - if not dst._data.is_contiguous(): - raise RuntimeError("Transformer Engine cast kernels require contiguous data") - tex.cast_to_fp8_noalloc( - src.view(1, -1), - scale, - dst._data.view(1, -1), - amax, - dst._scale_inv, - dst._fp8_dtype, - ) - - # This branch is where the FP8 parameters are updated in-place during optimization. - # Handle forward amax reduction. - post_optimizer_step_fwd_amax_reduction(dst) - else: - - # Invalid case - raise RuntimeError("Using Float8Tensor copy logic, but no Float8Tensor found") - - # Nothing to return for in-place ops - if dst_is_fp8: - dst._reset_caches() - - return None - - # Slice op - if func == aten.slice.Tensor: - tensor = args[0] - data = tensor._data - data_slice = data.__torch_dispatch__( - func, - types, - [data] + list(args[1:]), - kwargs, - ) - return Float8Tensor.make_like(tensor, data=data_slice) - - # Detach op - if func == aten.detach.default: - # Simply return a new Float8Tensor with the same attrs - return Float8Tensor.make_like( - args[0], - data=args[0]._data, - fp8_attrs=args[0]._fp8_attrs, - ) - - # View op - if func == aten.view.default: - tensor = args[0] - data = tensor._data - data_view = data.__torch_dispatch__( - func, - types, - [data] + list(args[1:]), - kwargs, - ) - return Float8Tensor.make_like( - tensor, - data=data_view, - fp8_attrs=tensor._fp8_attrs, - ) - - def maybe_unwrap(t): - if isinstance(t, Float8Tensor): - return t.from_float8() - return t - - def maybe_update_inplace(arg, new_arg, schema_arg): - """Update values of FP8 tensors - - Keep the same FP8 scaling factors. - - """ - if ( - isinstance(arg, Float8Tensor) - and isinstance(new_arg, torch.Tensor) - and hasattr(schema_arg, "alias_info") - and hasattr(schema_arg.alias_info, "is_write") - and schema_arg.alias_info.is_write - ): - arg.copy_(new_arg) - arg._reset_caches() - - # In-place op - if func._schema.is_mutable: - # Cast to higher precision, perform op, and cast values - # back to original FP8 buffers - new_args = tree_map(maybe_unwrap, args) - new_kwargs = tree_map(maybe_unwrap, kwargs) - schema_args = func._schema.arguments - args_len = len(args) - out = super().__torch_dispatch__(func, types, new_args, new_kwargs) - for arg, new_arg, schema_arg in zip(args, new_args, schema_args): - maybe_update_inplace(arg, new_arg, schema_arg) - for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]): - assert kwarg == new_kwarg == schema_arg.name, "name of the kw argument should match" - maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg) - return None - - # Default op - # Note: cast to higher precision and perform op - args = tree_map(maybe_unwrap, args) - if kwargs is not None: - kwargs = tree_map(maybe_unwrap, kwargs) - out = super().__torch_dispatch__(func, types, args, kwargs) - return out - - @classmethod - def _make_in_reduce_ex( - cls, - data: torch.Tensor, - fp8_dtype: tex.DType, - fp8_scale_inv: torch.Tensor, - dtype: torch.dtype, - ) -> Float8Tensor: - """Build Float8Tensor, for use in __reduce__ - - __reduce_ex__ assumes object constructor has positional - arguments. - - """ - return Float8Tensor( - data=data, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv, - dtype=dtype, - ) - - def __reduce_ex__(self, protocol: int) -> tuple: - """Custom pickling to remove references to FP8 metadata objects""" - return ( - Float8Tensor._make_in_reduce_ex, - (self._data, self._fp8_dtype, self._scale_inv, self.dtype), - ) - - def _get_data(self) -> Float8Tensor: - """Get tensor data property""" - return super().data - - def _set_data(self, tensor: torch.Tensor) -> None: - """Set tensor data property - - Cast tensor to FP8 and store in FP8 buffer. - - """ - with torch.no_grad(): - self.copy_(tensor) - - # Cast to FP8 when setting Float8Tensor.data - data = property(_get_data, _set_data) - - # Accessors for objects in self._fp8_attrs - # Note: We store FP8 attributes in a dictionary so we can share - # them between tensors with the same data, e.g. detached tensors. - # For convenience, we also expose them as property attributes. - _fp8_meta = property(**_make_fp8_attr_property_funcs("fp8_meta")) - _fp8_meta_forward = property(**_make_fp8_attr_property_funcs("fp8_meta_forward")) - _fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index")) - _fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype")) - _transpose = property(**_make_fp8_attr_property_funcs("transpose")) - _transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid")) - _scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv")) - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - return torch._C._disabled_torch_function_impl(func, types, args, kwargs) +__all__ = ["Float8Tensor"] diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 3613e1fa5e..3375b8ab7d 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -865,11 +865,17 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: # If primary weights are in fp8, wrap the parameter as Float8Tensor fp8_meta_index = self.param_init_meta[name].fp8_meta_index if self.primary_weights_in_fp8 and fp8_meta_index is not None: + dummy_amax = torch.empty( + (1, 1), + dtype=torch.float32, + device=param.device, + ) # Dummy buffer to avoid overwriting amax history param = Float8Tensor.to_float8( param, fp8_meta=self.fp8_meta, fp8_meta_index=fp8_meta_index, - amax=torch.empty(1, device="cuda"), # Dummy amax to avoid overwriting history. + amax=dummy_amax, + with_transpose_cache=torch.is_grad_enabled(), ) # Redo parameter wrap in case we broke it above @@ -891,7 +897,6 @@ def get_fp8_workspace( cache_name: Optional[str] = None, update_workspace: bool = True, skip_update_flag: Optional[torch.Tensor] = None, - with_transpose: bool = False, fsdp_group: dist_group_type = None, ) -> Float8Tensor: """Get FP8 workspace buffer and maybe update its values @@ -917,27 +922,30 @@ def get_fp8_workspace( skip_update_flag: torch.Tensor, optional GPU flag to skip updating the workspace. Take precedence over `update_workspace` if provided. - with_transpose: bool, default = `False` - Whether to initialize cached transpose in workspace. fsdp_group: bool, default = None FSDP process group that the weights are distributed over. """ - # Construct workspace if needed + # Try getting workspace from cache out = None if cache_name is not None: out = self._fp8_workspaces.get(cache_name, None) - # Gather cached Fp8 workspace if it's distributed - # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work - # for models initialized with Fp8 primary weights. - if ( - not isinstance(out, Float8Tensor) - and fsdp_group is not None - and out._data.shape != tensor.data.shape - ): - _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out) + # Gather cached Fp8 workspace if it's distributed + # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work + # for models initialized with Fp8 primary weights. + if ( + out is not None + and not isinstance(out, Float8Tensor) + and fsdp_group is not None + and out._data.shape != tensor.data.shape + ): + _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out) + + # Construct workspace if needed if out is None: + + # FP8 data if tensor is None or fp8_meta_forward is None or fp8_meta_index is None: raise ValueError( "tensor, fp8_meta_forward, and fp8_meta_index kwargs " @@ -947,16 +955,38 @@ def get_fp8_workspace( self.fp8_meta["recipe"], fprop_tensor=fp8_meta_forward, ) + data = torch.empty_like(tensor, dtype=torch.uint8) scale_inv = torch.empty([1], dtype=torch.float32, device=tensor.device) + + # Transpose cache + with_transpose_cache = torch.is_grad_enabled() + if ( + not with_transpose_cache + and is_fp8_activation_recompute_enabled() + and not in_fp8_activation_recompute_phase() + ): + with_transpose_cache = True + data_transpose = None + if with_transpose_cache: + data_transpose = torch.empty( + (tensor.size(-1), tensor.numel() // tensor.size(-1)), + dtype=torch.uint8, + device=tensor.device, + ) + + # Construct FP8 tensor out = Float8Tensor( - data=torch.empty_like(tensor, dtype=torch.uint8), + data=data, fp8_meta=self.fp8_meta, fp8_meta_forward=fp8_meta_forward, fp8_meta_index=fp8_meta_index, fp8_dtype=fp8_dtype, fp8_scale_inv=scale_inv, dtype=tensor.dtype, + data_transpose=data_transpose, ) + + # Update cache if cache_name is not None: self._fp8_workspaces[cache_name] = out update_workspace = True @@ -968,33 +998,17 @@ def get_fp8_workspace( if update_workspace: if tensor is None: raise ValueError("tensor kwarg must be provided to update FP8 workspace") - if with_transpose: - out.cast_transpose_( - tensor, - noop_flag=skip_update_flag, - ) + if is_in_onnx_export_mode(): + # ONNX export does not support fused cast-transpose + # kernel and requires that FP8 scales can be + # represented with constant ops. + transpose_cache = out._transpose + out._transpose = None + out.quantize_(tensor) + out._scale_inv.fill_(out._scale_inv.item()) + out._transpose = transpose_cache else: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=out._fp8_meta_forward, - ) - fp8_meta = out._fp8_meta[fp8_meta_key] - fp8_meta_index = out._fp8_meta_index - cast_to_fp8( - tensor, - fp8_meta, - fp8_meta_index, - out._fp8_dtype, - out=out._data, - ) - if is_in_onnx_export_mode(): - # ONNX export expects FP8 scales can be - # represented with constant ops. However, copying - # into a buffer involves an expand op for array - # broadcasting. We work around this by filling the - # buffer instead. - out._scale_inv.fill_(fp8_meta.scale_inv[fp8_meta_index].item()) - else: - out._scale_inv.copy_(fp8_meta.scale_inv[fp8_meta_index]) + out.quantize_(tensor, noop_flag=skip_update_flag) return out diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index ca100392c7..10c8d91551 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -28,8 +28,6 @@ from ..distributed import ( set_tensor_model_parallel_attributes, get_distributed_world_size, - is_fp8_activation_recompute_enabled, - in_fp8_activation_recompute_phase, ) from ..cpp_extensions import ( cast_to_fp8, @@ -760,22 +758,12 @@ def forward( weight_tensors_fp8 = [None] * self.num_gemms if self.fp8: - with_transpose = torch.is_grad_enabled() - if ( - not with_transpose - and is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() - ): - with_transpose = True for i in range(self.num_gemms): if isinstance(weight_tensors[i], Float8Tensor): - # Fill transpose cache in FP8 tensor if needed - update_transpose_cache = with_transpose - if update_transpose_cache: - update_transpose_cache = ( - is_first_microbatch or skip_fp8_weight_update is not None - ) - if update_transpose_cache: + # Make sure transpose cache is valid, if present + # Note: Transpose cache may have been invalidated + # externally, e.g. by optimizer. + if weight_tensors[i]._transpose is not None: weight_tensors[i].transpose_2d( fill_cache=True, noop_flag=skip_fp8_weight_update, @@ -790,7 +778,6 @@ def forward( cache_name=(None if is_first_microbatch is None else f"weight{i}"), update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, - with_transpose=with_transpose, ) from ..cpu_offload import CPUOffloadEnabled diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 9586d6d345..da77879e06 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -36,8 +36,6 @@ allreduce, reduce_scatter_along_first_dim, gather_along_first_dim, - is_fp8_activation_recompute_enabled, - in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, _fsdp_gather_tensors, ) @@ -47,6 +45,7 @@ from ._common import _apply_normalization, _noop_cat from ..float8_tensor import Float8Tensor from ..export import is_in_onnx_export_mode +from ..tensor import QuantizedTensor __all__ = ["LayerNormLinear"] @@ -1151,14 +1150,14 @@ def forward( # Get concatenated weight and bias tensors unfused_weights = [getattr(self, name) for name in self.weight_names] - if any(isinstance(w, Float8Tensor) for w in unfused_weights): + if any(isinstance(w, QuantizedTensor) for w in unfused_weights): if self.fp8: if len(unfused_weights) != 1: raise RuntimeError( - "Splitting Float8Tensor into multiple params is not supported" + "Splitting QuantizedTensor into multiple params is not supported" ) else: - unfused_weights = [w.from_float8() for w in unfused_weights] + unfused_weights = [w.dequantize() for w in unfused_weights] weight_tensor = _noop_cat(unfused_weights) if self.use_bias: bias_tensor = _noop_cat( @@ -1170,32 +1169,18 @@ def forward( # Initialize FP8 weights if needed weight_fp8 = None if self.fp8: - with_transpose = torch.is_grad_enabled() - if ( - not with_transpose - and is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() - ): - with_transpose = True if isinstance(weight_tensor, Float8Tensor): - # Fill transpose cache in FP8 tensor if needed - update_transpose_cache = with_transpose - if update_transpose_cache: - update_transpose_cache = ( - is_first_microbatch or skip_fp8_weight_update is not None - ) - if update_transpose_cache: + # Make sure transpose cache is valid, if present + # Note: Transpose cache may have been invalidated + # externally, e.g. by optimizer. + if weight_tensor._transpose is not None: weight_tensor.transpose_2d( fill_cache=True, noop_flag=skip_fp8_weight_update, ) else: # FP8 cast to workspace buffer - update_workspace = ( - is_first_microbatch is None - or is_first_microbatch - or skip_fp8_weight_update is not None - ) + update_workspace = is_first_microbatch is None or is_first_microbatch weight_fp8 = self.get_fp8_workspace( tensor=weight_tensor, fp8_meta_forward=True, @@ -1203,7 +1188,6 @@ def forward( cache_name=(None if is_first_microbatch is None else "weight"), update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, - with_transpose=with_transpose, ) from ..cpu_offload import CPUOffloadEnabled diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index dc9bef645f..b802c972d4 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -42,8 +42,6 @@ allreduce, reduce_scatter_along_first_dim, gather_along_first_dim, - is_fp8_activation_recompute_enabled, - in_fp8_activation_recompute_phase, use_reentrant_activation_recompute, _fsdp_scatter_tensors, _fsdp_gather_tensors, @@ -1485,19 +1483,8 @@ def forward( fc2_weight_fp8 = None if self.fp8: update_workspace = is_first_microbatch is None or is_first_microbatch - with_transpose = torch.is_grad_enabled() - if ( - is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() - ): - with_transpose = True - update_transpose_cache = with_transpose - if update_transpose_cache: - update_transpose_cache = ( - is_first_microbatch or skip_fp8_weight_update is not None - ) if isinstance(fc1_weight, Float8Tensor): - if update_transpose_cache: + if fc1_weight._transpose is not None: fc1_weight.transpose_2d( fill_cache=True, noop_flag=skip_fp8_weight_update, @@ -1513,10 +1500,9 @@ def forward( cache_name=cache_name, update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, - with_transpose=with_transpose, ) if isinstance(fc2_weight, Float8Tensor): - if update_transpose_cache: + if fc2_weight._transpose is not None: fc2_weight.transpose_2d( fill_cache=True, noop_flag=skip_fp8_weight_update, @@ -1532,7 +1518,6 @@ def forward( cache_name=cache_name, update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, - with_transpose=with_transpose, ) # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f92a2db2d9..a7be82ccf1 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -33,8 +33,6 @@ allreduce, reduce_scatter_along_first_dim, gather_along_first_dim, - is_fp8_activation_recompute_enabled, - in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, _fsdp_gather_tensors, ) @@ -49,6 +47,7 @@ from ..graph import is_graph_capturing from ..float8_tensor import Float8Tensor from ..export import is_in_onnx_export_mode +from ..tensor import QuantizedTensor __all__ = ["Linear"] @@ -938,19 +937,19 @@ def forward( with self.prepare_forward( inp, is_first_microbatch, - allow_non_contiguous=isinstance(inp, Float8Tensor), + allow_non_contiguous=isinstance(inp, QuantizedTensor), ) as inp: # Get concatenated weight and bias tensors unfused_weights = [getattr(self, name) for name in self.weight_names] - if any(isinstance(w, Float8Tensor) for w in unfused_weights): + if any(isinstance(w, QuantizedTensor) for w in unfused_weights): if self.fp8: if len(unfused_weights) != 1: raise RuntimeError( - "Splitting Float8Tensor into multiple params is not supported" + "Splitting QuantizedTensor into multiple params is not supported" ) else: - unfused_weights = [w.from_float8() for w in unfused_weights] + unfused_weights = [w.dequantize() for w in unfused_weights] weight_tensor = _noop_cat(unfused_weights) if self.use_bias: bias_tensor = _noop_cat( @@ -962,21 +961,11 @@ def forward( # Initialize FP8 weights if needed weight_fp8 = None if self.fp8: - with_transpose = torch.is_grad_enabled() - if ( - not with_transpose - and is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() - ): - with_transpose = True if isinstance(weight_tensor, Float8Tensor): - # Fill transpose cache in FP8 tensor if needed - update_transpose_cache = with_transpose - if update_transpose_cache: - update_transpose_cache = ( - is_first_microbatch or skip_fp8_weight_update is not None - ) - if update_transpose_cache: + # Make sure transpose cache is valid, if present + # Note: Transpose cache may have been invalidated + # externally, e.g. by optimizer. + if weight_tensor._transpose is not None: weight_tensor.transpose_2d( fill_cache=True, noop_flag=skip_fp8_weight_update, @@ -991,7 +980,6 @@ def forward( cache_name=(None if is_first_microbatch is None else "weight"), update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, - with_transpose=with_transpose, fsdp_group=self.fsdp_group, ) diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 77efef4ab6..12270d8340 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -9,54 +9,12 @@ import torch -from transformer_engine.pytorch.float8_tensor import Float8Tensor - - -def canonicalize_device(device: Optional[torch.device | str]) -> torch.device: - """Canonicalize PyTorch device - - If `None`, then returns the default CUDA device. - - """ - if device is None: - # Use default CUDA device - device = torch.get_default_device() - if device.type != "cuda": - device = torch.device("cuda", torch.cuda.current_device()) - elif not isinstance(device, torch.device): - device = torch.device(device) - if device.type == "cuda" and device.index is None: - device = torch.device("cuda", torch.cuda.current_device()) - return device - - -def canonicalize_dtype(dtype: Optional[torch.dtype]) -> torch.dtype: - """Canonicalize PyTorch datatype - - If `None`, then returns the default PyTorch datatype. - - """ - if dtype is None: - # Use default dtype - dtype = torch.get_default_dtype() - return dtype - - -def devices_match(device1: torch.device, device2: torch.device) -> bool: - """Whether two devices are the same""" - device1 = torch.device(device1) - device2 = torch.device(device2) - if device1.type != device2.type: - return False - if device1.type == "cuda": - index1 = device1.index - index2 = device2.index - if index1 is None: - index1 = torch.cuda.current_device() - if index2 is None: - index2 = torch.cuda.current_device() - return index1 == index2 - return device1 == device2 +from ..tensor import Float8Tensor +from ..utils import ( + canonicalize_device, # pylint: disable=unused-import + canonicalize_dtype, # pylint: disable=unused-import + devices_match, # pylint: disable=unused-import +) def is_float8_tensor(tensor: Any) -> bool: @@ -92,7 +50,13 @@ def convert_tensor( # Convert FP8 tensor if is_float8_tensor(tensor): - data = tensor._data.to(device=device, memory_format=memory_format) + data = tensor._data + if not devices_match(device, data.device): + data = data.to(device=device) + if memory_format != torch.preserve_format and not data.is_contiguous( + memory_format=memory_format + ): + data = data.contiguous(memory_format=memory_format) return Float8Tensor.make_like( tensor, data=data, diff --git a/transformer_engine/pytorch/ops/basic/all_reduce.py b/transformer_engine/pytorch/ops/basic/all_reduce.py index 622346b1c5..f466ade3a3 100644 --- a/transformer_engine/pytorch/ops/basic/all_reduce.py +++ b/transformer_engine/pytorch/ops/basic/all_reduce.py @@ -9,11 +9,8 @@ import torch -from transformer_engine.pytorch.ops.op import ( - BasicOperation, - OperationContext, -) -from .._common import is_float8_tensor +from ...tensor import QuantizedTensor +from ..op import BasicOperation, OperationContext class AllReduce(BasicOperation): @@ -54,8 +51,8 @@ def op_forward( # Perform all-reduce x = input_ - if is_float8_tensor(x): - x = x.from_float8() + if isinstance(x, QuantizedTensor): + x = x.dequantize() x = x.contiguous() torch.distributed.all_reduce(x, group=self.process_group) return x diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 826807d1c0..ce72dd8a55 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -289,10 +289,18 @@ def reset_parameters(self) -> None: # Cast to FP8 if needed if self._with_fp8_parameters: + dummy_amax = torch.empty( + (1, 1), + dtype=torch.float32, + device=self.device, + ) # Dummy buffer to avoid overwriting amax history weight = Float8Tensor.to_float8( weight, fp8_meta=self.get_fp8_meta("param"), + fp8_meta_forward=True, fp8_meta_index=0, + amax=dummy_amax, + with_transpose_cache=torch.is_grad_enabled(), ) # Save updated parameter @@ -467,25 +475,19 @@ def _functional_forward( input_fp8_meta["recipe"], fprop_tensor=True, ) - x_fp8 = Float8Tensor( - data=torch.empty_like(x_local, dtype=torch.uint8), + with_transpose_cache = weight.requires_grad + if tensor_parallel_mode == "column" and sequence_parallel: + with_transpose_cache = False + x_local = Float8Tensor.to_float8( + x_local, fp8_meta=input_fp8_meta, fp8_meta_forward=True, fp8_meta_index=0, fp8_dtype=fp8_dtype, - fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), - dtype=dtype, + with_transpose_cache=with_transpose_cache, ) - with_cast_transpose = weight.requires_grad - if tensor_parallel_mode == "column" and sequence_parallel: - with_cast_transpose = False - if with_cast_transpose: - x_fp8.cast_transpose_(x_local) - else: - x_fp8.copy_(x_local) - x_local = x_fp8 elif not with_fp8_compute and is_float8_tensor(x_local): - x_local = x_local.from_float8() + x_local = x_local.dequantize() x = x_local x_async = None if tensor_parallel_mode == "column" and sequence_parallel: @@ -510,11 +512,12 @@ def _functional_forward( w = Float8Tensor.to_float8( w, fp8_meta=weight_fp8_meta, + fp8_meta_forward=True, fp8_meta_index=0, fp8_dtype=fp8_dtype, ) elif not with_fp8_compute and is_float8_tensor(w): - w = w.from_float8() + w = w.dequantize() # Check bias tensor b = None @@ -815,25 +818,19 @@ def _functional_backward( grad_output_fp8_meta["recipe"], fprop_tensor=False, ) - dy_fp8 = Float8Tensor( - data=torch.empty_like(dy, dtype=torch.uint8), + with_transpose_cache = weight_requires_grad + if tensor_parallel_mode == "row" and sequence_parallel: + with_transpose_cache = False + dy = Float8Tensor.to_float8( + dy, fp8_meta=grad_output_fp8_meta, fp8_meta_forward=False, fp8_meta_index=0, fp8_dtype=fp8_dtype, - fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), - dtype=dtype, + with_transpose_cache=with_transpose_cache, ) - with_cast_transpose = weight_requires_grad - if tensor_parallel_mode == "row" and sequence_parallel: - with_cast_transpose = False - if with_cast_transpose: - dy_fp8.cast_transpose_(dy) - else: - dy_fp8.copy_(dy) - dy = dy_fp8 elif not with_fp8_compute and is_float8_tensor(dy): - dy = dy.from_float8() + dy = dy.dequantize() if tensor_parallel_mode == "row" and sequence_parallel: dy, dy_async = gather_along_first_dim( dy, @@ -853,26 +850,24 @@ def _functional_backward( device=device, dtype=dtype, ) + x_is_sharded = tensor_parallel_mode == "column" and sequence_parallel if with_fp8_compute and not is_float8_tensor(x_local): fp8_dtype = get_fp8_te_dtype( input_fp8_meta["recipe"], fprop_tensor=True, ) - x_fp8 = Float8Tensor( - data=torch.empty_like(x_local, dtype=torch.uint8), + x_local = Float8Tensor.to_float8( + x_local, fp8_meta=input_fp8_meta, fp8_meta_forward=True, fp8_meta_index=0, fp8_dtype=fp8_dtype, - fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), - dtype=dtype, + with_transpose_cache=(not x_is_sharded), ) - x_fp8.cast_transpose_(x_local) - x_local = x_fp8 elif not with_fp8_compute and is_float8_tensor(x_local): x_local = x_local.from_float8() x = x_local - if tensor_parallel_mode == "column" and sequence_parallel: + if x_is_sharded: x, x_async = gather_along_first_dim( x_local, tensor_parallel_group, @@ -898,19 +893,16 @@ def _functional_backward( weight_fp8_meta["recipe"], fprop_tensor=True, ) - w_fp8 = Float8Tensor( - data=torch.empty_like(w, dtype=torch.uint8), + w = Float8Tensor.to_float8( + w, fp8_meta=weight_fp8_meta, fp8_meta_forward=True, fp8_meta_index=0, fp8_dtype=fp8_dtype, - fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), - dtype=dtype, + with_transpose_cache=True, ) - w_fp8.cast_transpose_(w) - w = w_fp8 elif not with_fp8_compute and is_float8_tensor(w): - w = w.from_float8() + w = w.dequantize() # Construct grad input tensor if grad_input is not None: diff --git a/transformer_engine/pytorch/ops/basic/reduce_scatter.py b/transformer_engine/pytorch/ops/basic/reduce_scatter.py index 996ca2da31..c78dbc2877 100644 --- a/transformer_engine/pytorch/ops/basic/reduce_scatter.py +++ b/transformer_engine/pytorch/ops/basic/reduce_scatter.py @@ -9,12 +9,9 @@ import torch -from transformer_engine.pytorch.float8_tensor import Float8Tensor -from transformer_engine.pytorch.ops.op import ( - BasicOperation, - OperationContext, -) -from .._common import convert_tensor, is_float8_tensor +from ...tensor import Float8Tensor, QuantizedTensor +from ..op import BasicOperation, OperationContext +from .._common import convert_tensor class ReduceScatter(BasicOperation): @@ -63,8 +60,8 @@ def op_forward( # Check input tensor x = input_ - if is_float8_tensor(x): - x = x.from_float8() + if isinstance(x, QuantizedTensor): + x = x.dequantize() x = x.contiguous() # Perform reduce-scatter @@ -96,7 +93,7 @@ def op_backward( # Perform all-gather dy = convert_tensor(grad_output, memory_format=torch.contiguous_format) dx = None - if is_float8_tensor(dy): + if isinstance(dy, Float8Tensor): dx = Float8Tensor.make_like( dy, data=torch.empty( @@ -111,6 +108,8 @@ def op_backward( group=self.process_group, ) else: + if isinstance(dy, QuantizedTensor): + dy = dy.dequantize() dx = torch.empty(input_dims, dtype=dy.dtype, device=dy.device) torch.distributed.all_gather_into_tensor( dx, diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py new file mode 100644 index 0000000000..2bad862768 --- /dev/null +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Custom tensor classes""" + +from .float8_tensor import Float8Tensor +from .quantized_tensor import QuantizedTensor diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py new file mode 100644 index 0000000000..610523a10d --- /dev/null +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -0,0 +1,972 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tensor class with FP8 data""" +from __future__ import annotations +from typing import Any, Dict, Optional, Tuple +import warnings + +import torch +import transformer_engine_torch as tex + +from transformer_engine_torch import DType as TE_DType +from ..constants import TE_DType as torch_to_transformer_engine_dtype +from ..cpp_extensions import ( + cast_from_fp8, + cast_to_fp8, + fp8_cast_transpose_fused, +) +from ..fp8 import FP8GlobalStateManager +from ..utils import devices_match +from .quantized_tensor import QuantizedTensor + +aten = torch.ops.aten +updated_fp8_params = {} + + +def _make_fp8_attr_property_funcs(name: str) -> Any: + """Make accessors for an FP8 attribute + + We store FP8 attributes in a dictionary so we can share them + between tensors with the same data, e.g. detached tensors. For + convenience, we also expose them as property attributes. This + function creates the accessors for property attributes. + + Parameters + ---------- + name: str + Key in dictionary of FP8 attributes + + """ + + def get_func(self) -> Any: + return self._fp8_attrs[name] + + def set_func(self, value: Any) -> None: + self._fp8_attrs[name] = value + + def del_func(self) -> None: + del self._fp8_attrs[name] + + return dict(fget=get_func, fset=set_func, fdel=del_func) + + +class _FromFloat8Func(torch.autograd.Function): + """Cast from FP8 to other dtype""" + + @staticmethod + def forward( + _ctx: torch.autograd.function.FunctionCtx, # unused + tensor: Float8Tensor, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + return tensor.dequantize(dtype=dtype) + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # Assume that we want gradients in full precision + return grad, None + + +def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None: + """Amax scale and update when there is at least 1 trainable FP8 parameter.""" + param_id = id(param._data) + + if param_id not in FP8GlobalStateManager.fp8_param_to_autocast: + return + + autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id] + + if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params: + return + + if autocast_key in updated_fp8_params: + updated_fp8_params[autocast_key].add(param_id) + else: + updated_fp8_params[autocast_key] = {param_id} + + current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key] + # All FP8 trainable parameters have been updated. + if updated_fp8_params[autocast_key] == current_fp8_params_set: + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True, fp8_weights=True) + del updated_fp8_params[autocast_key] + + +class _ToFloat8Func(torch.autograd.Function): + """Cast to FP8 from other dtype""" + + @staticmethod + def forward( + _ctx: torch.autograd.function.FunctionCtx, # unused + tensor: torch.Tensor, + fp8_meta: Optional[Dict[str, Any]] = None, + fp8_meta_forward: bool = True, + fp8_meta_index: Optional[int] = None, + fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + with_transpose_cache: bool = False, + ) -> Float8Tensor: + + # Tensor attributes + dtype = tensor.dtype + if dtype not in (torch.float32, torch.bfloat16, torch.float16): + dtype = torch.float32 + device = tensor.device + if device.type != "cuda": + device = torch.device("cuda") + + # FP8 data buffer + data = torch.empty(tensor.size(), dtype=torch.uint8, device=device) + + # Check scale + if scale is None and fp8_meta is None: + scale = 1 + if scale is not None: + if isinstance(scale, torch.Tensor): + scale = scale.to(device=device, dtype=torch.float32) + else: + scale = torch.full([1], scale, dtype=torch.float32, device=device) + + # Check scale-inverse + if scale_inv is None: + scale_inv = torch.empty([1], dtype=torch.float32, device=device) + elif not devices_match(scale_inv.device, device) or scale_inv.dtype != dtype: + scale_inv = scale_inv.to(device=device, dtype=torch.float32) + + # Transpose cache + data_transpose = None + if with_transpose_cache: + data_transpose = torch.empty( + (data.size(-1), data.numel() // data.size(-1)), + dtype=torch.uint8, + device=tensor.device, + ) + + # Construct FP8 tensor + out = Float8Tensor( + data=data, + fp8_meta=fp8_meta, + fp8_meta_forward=fp8_meta_forward, + fp8_meta_index=fp8_meta_index, + fp8_dtype=fp8_dtype, + fp8_scale_inv=scale_inv, + dtype=dtype, + data_transpose=data_transpose, + ) + + # Cast to FP8 tensor + out.quantize_(tensor, scale=scale, amax=amax) + + return out + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # Assume that we want gradients in full precision + return grad, None, None, None, None, None, None, None + + +class _IdentityFunc(torch.autograd.Function): + """Identity function + + If constructor keyword-arguments are provided, then construct a + new Float8Tensor using the provided tensor's attributes. + + """ + + @staticmethod + def forward( + ctx, + tensor: Float8Tensor, + init_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + + # Return input tensor if constructor kwargs are not provided + ctx.input_dtype = tensor.dtype + if init_kwargs is None: + return tensor + + # Construct new tensor if constructor kwargs are provided + default_kwargs = dict( + data=tensor._data, + fp8_meta=tensor._fp8_meta, + fp8_meta_forward=tensor._fp8_meta_forward, + fp8_meta_index=tensor._fp8_meta_index, + fp8_dtype=tensor._fp8_dtype, + fp8_scale_inv=tensor._scale_inv, + dtype=tensor.dtype, + ) + for key, val in default_kwargs.items(): + if key not in init_kwargs: + init_kwargs[key] = val + return Float8Tensor(**init_kwargs) + + @staticmethod + def backward(ctx, grad): + return grad.to(ctx.input_dtype), None + + +class _ViewFunc(torch.autograd.Function): + """View function + + View the Float8Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: torch.Tensor, + shape: Tuple[int] = None, + ) -> torch.Tensor: + + # Return input tensor if shape is not provided + ctx.shape = tensor.shape + if shape is None: + return tensor + + # Construct new tensor if shape is provided + if isinstance(tensor, Float8Tensor): + return Float8Tensor.make_like( + tensor, + data=tensor._data.view(*shape), + ) + return tensor.view(*shape) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + + if isinstance(grad, Float8Tensor): + dgrad = Float8Tensor.make_like( + grad, + data=grad._data.view(ctx.shape), + ) + return dgrad, None + return grad.view(ctx.shape), None + + +class _ReshapeFunc(torch.autograd.Function): + """Reshape function + + Reshape the Float8Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: torch.Tensor, + shape: Tuple[int] = None, + ) -> torch.Tensor: + + # Return input tensor if shape is not provided + ctx.shape = tensor.shape + if shape is None: + return tensor + + # Construct new tensor if shape is provided + if isinstance(tensor, Float8Tensor): + return Float8Tensor.make_like( + tensor, + data=tensor._data.reshape(*shape), + ) + return tensor.reshape(*shape) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + + if isinstance(grad, Float8Tensor): + dgrad = Float8Tensor.make_like( + grad, + data=grad._data.reshape(ctx.shape), + ) + return dgrad, None + return grad.reshape(ctx.shape), None + + +class Float8Tensor(QuantizedTensor): + """Experimental tensor class with FP8 data + + The tensor presents as having a standard, higher-precision dtype, + but the data itself is (scaled) FP8. For most tensor operations, + the data will be cast to the nominal dtype before performing the + operation. + + Parameters + ---------- + data: torch.Tensor + Raw FP8 data in a uint8 tensor + fp8_attrs: dict, optional + FP8 metadata, primarily managed by Float8Tensor. If + provided, all other FP8 configuration is ignored. + fp8_meta: dict, optional + FP8 metadata object, primarily managed by TE modules. + fp8_meta_forward: bool, default = `True` + Whether to access the FP8 metadata for the + forward pass. Ignored if fp8_meta is not + provided. + fp8_meta_index: int, optional + Index to access in FP8 meta tensors. Required if + fp8_meta is provided and otherwise ignored. + fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3 + FP8 format. + fp8_scale_inv: torch.Tensor + Reciprocal of the scaling factor applied when + casting to FP8, i.e. the scaling factor that must + be applied when casting from FP8 to higher + precision. Can be inferred from fp8_meta if + provided. + dtype: torch.dtype, default = torch.float32 + Nominal tensor datatype. + + """ + + def __new__( + cls, + *, + data: torch.Tensor, + fp8_attrs: Optional[Dict[str, Any]] = None, + fp8_meta: Optional[Dict[str, Any]] = None, + fp8_meta_forward: bool = True, + fp8_meta_index: Optional[int] = None, + fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, + fp8_scale_inv: Optional[torch.Tensor] = None, + dtype: torch.dtype = torch.float32, + data_transpose: Optional[torch.Tensor] = None, + ): + + # Check that data buffer is valid + if data.element_size() != 1: + raise ValueError( + f"Float8Tensor requires data buffer with 8-bit dtype (got dtype={data.dtype})" + ) + if data.requires_grad: + raise ValueError("Float8Tensor requires non-differentiable data buffer") + if not data.is_cuda: + data = data.cuda() + + # Initialize tensor object + self = torch.Tensor._make_wrapper_subclass( + cls, + data.size(), + strides=data.stride(), + storage_offset=data.storage_offset(), + dtype=dtype, + layout=data.layout, + requires_grad=data.requires_grad, + device=data.device, + ) + self._data: torch.Tensor = data + + # Initialize dict of class attributes + # Note: We store FP8 attributes in a dictionary so we can + # share them between tensors with the same data, e.g. detached + # tensors. + self._fp8_attrs: dict + if fp8_attrs is None: + self._fp8_attrs = {} + else: + self._fp8_attrs = fp8_attrs + return self + + # FP8 meta tensors + if fp8_meta is not None and fp8_meta_index is None: + raise ValueError( + "To initialize Float8Tensor with FP8 meta tensors, " + "the FP8 meta tensor index must also be provided" + ) + self._fp8_meta: Optional[Dict[str, Any]] = fp8_meta + self._fp8_meta_forward: bool = fp8_meta_forward + self._fp8_meta_index: Optional[int] = fp8_meta_index + + # FP8 dtype + assert fp8_dtype in ( + TE_DType.kFloat8E4M3, + TE_DType.kFloat8E5M2, + ), f"Unsupported fp8_dtype {fp8_dtype}." + self._fp8_dtype: TE_DType = fp8_dtype + + # FP8 scale-inverse + if fp8_scale_inv is None and self._fp8_meta is not None: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=self._fp8_meta_forward, + ) + fp8_scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] + fp8_scale_inv = fp8_scale_inv.detach().view(1).clone() + if fp8_scale_inv is None: + raise ValueError( + "Attempted to initialize Float8Tensor without specifying scale-inverse" + ) + if not isinstance(fp8_scale_inv, torch.Tensor): + fp8_scale_inv = torch.full( + [1], + fp8_scale_inv, + dtype=torch.float32, + device=self._data.device, + ) + if fp8_scale_inv.numel() != 1: + raise ValueError( + "Attempted to initialize Float8Tensor with invalid scale-inverse tensor" + ) + if fp8_scale_inv.dim() != 1: + fp8_scale_inv = fp8_scale_inv.reshape(1) + if ( + not devices_match(fp8_scale_inv.device, self._data.device) + or fp8_scale_inv.dtype != torch.float32 + ): + fp8_scale_inv = fp8_scale_inv.to( + device=self._data.device, + dtype=torch.float32, + ) + self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv + + # FP8 transpose cache + self._transpose: Optional[Float8Tensor] = data_transpose + self._transpose_invalid: bool = self._transpose is None + + return self + + @classmethod + def make_like( + cls, + tensor: Float8Tensor, + *, + data: torch.Tensor, + fp8_attrs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Float8Tensor: + """Use attributes of a Float8Tensor to create another Float8Tensor + + See constructor for list of keyword arguments. + + """ + default_kwargs = dict( + fp8_meta=tensor._fp8_meta, + fp8_meta_forward=tensor._fp8_meta_forward, + fp8_meta_index=tensor._fp8_meta_index, + fp8_dtype=tensor._fp8_dtype, + fp8_scale_inv=tensor._scale_inv, + dtype=tensor.dtype, + ) + for key, val in default_kwargs.items(): + if key not in kwargs: + kwargs[key] = val + return Float8Tensor(data=data, fp8_attrs=fp8_attrs, **kwargs) + + def __repr__(self): + return ( + "Float8Tensor(" + f"fp8_dtype={self._fp8_dtype}, " + f"scale_inv={self._scale_inv.item()}, " + f"data={self.from_float8(dtype=self.dtype)}" + ")" + ) + + def dequantize(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + + # Convert PyTorch dtype to TE dtype + if dtype is None: + dtype = self.dtype + dtype = torch_to_transformer_engine_dtype[dtype] + + # Make sure FP8 data is in expected format + data = self._data + if data.device.type != "cuda": + data = data.cuda() + if not data.is_contiguous(): + data = data.contiguous() + if data.dim() != 2: + data = data.view(1, -1) + + # Cast from FP8 + out = cast_from_fp8( + data.view(1, -1), + None, # fp8_meta_tensor + None, # fp8_tensor + self._fp8_dtype, + dtype, + scale_inv=self._scale_inv, + ) + + # Make sure output is in expected format + if out.size() != self.size(): + out = out.view(self.size()) + return out + + def from_float8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from Float8Tensor + + By default the resulting tensor's dtype is the + Float8Tensor's nominal dtype. + """ + return _FromFloat8Func.apply(self, dtype) + + def quantize_( + self, + tensor: torch.Tensor, + *, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + noop_flag: Optional[torch.Tensor] = None, + ) -> Float8Tensor: + """Update FP8 data + + Parameters + ---------- + tensor: torch.Tensor + Tensor to copy from + scale: torch.Tensor, optional + Scaling factor to use for FP8 quantization + amax: torch.Tensor, optional + History of maximum absolute values. The first entry will + be updated with the absmax of `tensor`. + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid performing update + + """ + src = tensor + dst = self + + # In-place operations invalidate transpose cache + self._reset_caches() + + # Special logic if other tensor is Float8Tensor + if isinstance(src, Float8Tensor): + + # Cast to plain tensor if FP8 dtypes don't match + if dst._fp8_dtype != src._fp8_dtype: + return dst.quantize_(src.dequantize()) + + # Directly copy FP8 data + dst._data.copy_(src._data.detach()) + dst._scale_inv.copy_(src._scale_inv.detach()) + if amax is not None or dst._fp8_meta is not None: + src_amax: torch.Tensor + if src._fp8_meta is None: + src_min, src_max = src.dequantize().aminmax() + src_amax = torch.maximum(-src_min, src_max) + else: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=src._fp8_meta_forward, + ) + fp8_meta_index = src._fp8_meta_index + src_amax = src._fp8_meta[fp8_meta_key].amax_history[0, fp8_meta_index] + dst_amax: torch.Tensor + if amax is None: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=dst._fp8_meta_forward, + ) + fp8_meta_index = dst._fp8_meta_index + dst_amax = dst._fp8_meta[fp8_meta_key].amax_history[0, fp8_meta_index] + else: + dst_amax = amax + if dst_amax.dim() > 0: + dst_amax = dst_amax[tuple([0] * dst_amax.dim())] + torch.maximum(src_amax, dst_amax, out=dst_amax) + if dst._transpose is not None: + if src._transpose is None: + dst.transpose_2d(force_compute=True, fill_cache=True) + else: + dst._transpose.copy_(src._transpose) + dst._transpose_invalid = False + return self + + # Convert QuantizedTensor to plain tensor + if isinstance(src, QuantizedTensor): + return dst.quantize_(src.dequantize()) + + # Make sure input is in expected format + if src.size() != dst.size(): + src = src.expand(dst.size()) + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if src.dtype not in (torch.float32, torch.bfloat16, torch.float16): + src = src.float() + if not src.is_contiguous(): + src = src.contiguous() + + # Make sure FP8 scaling factors are in expected format + if scale is not None: + if isinstance(scale, torch.Tensor): + if not devices_match(scale.device, dst.device) or scale.dtype != torch.float32: + scale = scale.to(device=dst.device, dtype=torch.float32) + else: + scale = torch.full([1], scale, dtype=torch.float32, device=dst.device) + if amax is not None: + while amax.dim() < 2: + amax = amax.unsqueeze(0) + if not devices_match(amax.device, dst.device): + raise ValueError( + f"Invalid device for amax (expected {dst.device}, found {amax.device})" + ) + if amax.dtype != torch.float32: + raise ValueError(f"Invalid dtype for amax (expected float32, found {amax.type})") + + # Default FP8 scaling factors + fp8_meta = None + if dst._fp8_meta is None: + if scale is None: + scale = dst._scale_inv.reciprocal() + if amax is None: + amax = torch.empty((1, 1), dtype=torch.float32, device=dst.device) + else: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=dst._fp8_meta_forward, + ) + fp8_meta = dst._fp8_meta[fp8_meta_key] + + # Check local data + if not dst._data.is_contiguous(): + raise RuntimeError("Transformer Engine cast kernels require contiguous data") + + # Perform FP8 cast + if dst._transpose is None: + dst_data = dst._data + if src.dim() != 2: + src = src.view(1, -1) + dst_data = dst_data.view(1, -1) + cast_to_fp8( + src, + fp8_meta, + dst._fp8_meta_index, + dst._fp8_dtype, + out=dst_data, + scale=scale, + amax=amax, + scale_inv=dst._scale_inv, + ) + else: + fp8_cast_transpose_fused( + src.view(-1, src.size(-1)), + fp8_meta, + dst._fp8_meta_index, + dst._fp8_dtype, + cast_out=dst._data, + transpose_out=dst._transpose, + scale=scale, + amax=amax, + scale_inv=dst._scale_inv, + noop_flag=noop_flag, + ) + dst._transpose_invalid = False + + # Callback hook to perform amax reduction after optimizer step + post_optimizer_step_fwd_amax_reduction(self) + + return self + + @classmethod + def to_float8( + cls, + tensor: torch.Tensor, + *, + fp8_meta: Optional[Dict[str, Any]] = None, + fp8_meta_forward: bool = True, + fp8_meta_index: Optional[int] = None, + fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + with_transpose_cache: bool = False, + ): + """Construct Float8Tensor from plain PyTorch tensor""" + return _ToFloat8Func.apply( + tensor, + fp8_meta, + fp8_meta_forward, + fp8_meta_index, + fp8_dtype, + scale, + amax, + scale_inv, + with_transpose_cache, + ) + + def detach(self) -> Float8Tensor: + return Float8Tensor.make_like( + self, + data=self._data, + fp8_attrs=self._fp8_attrs, + ) + + def clone(self) -> Float8Tensor: + data = self._data.detach().clone() + data_transpose = None + if self._transpose is not None: + data_transpose = self._transpose.detach().clone() + return _IdentityFunc.apply( + self, + dict( + data=data, + data_transpose=data_transpose, + ), + ) + + def view(self, *shape: Tuple[int]) -> Float8Tensor: + return _ViewFunc.apply(self, shape) + + def reshape(self, *shape: Tuple[int]) -> Float8Tensor: + return _ReshapeFunc.apply(self, shape) + + def contiguous( + self, + *, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> Float8Tensor: + """Returns tensor with data in provided memory format + + Returns `self` if data is already in correct memory format. + + """ + if self._data.is_contiguous(memory_format=memory_format): + return self + return _IdentityFunc.apply( + self, + {"data": self._data.detach().contiguous(memory_format=memory_format)}, + ) + + def transpose_2d( + self, + *, + force_compute: bool = False, + fill_cache: bool = False, + noop_flag: Optional[torch.Tensor] = None, + cache: Optional[bool] = None, + ) -> torch.Tensor: + """ + 2D transpose with caching support. + + Parameters + ---------- + force_compute: bool, default = `False` + Force computation of transpose. Otherwise use + cached values, if possible. + fill_cache: bool, default = `False` + Cache output tensor for future function calls. + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid updating + cached values, if possible. + cache: bool, deprecated + + """ + + # Handle deprecated cache kwarg + if cache is not None: + msg = ( + "cache kwarg for Float8Tensor.transpose_2d is deprecated, " + "please use force_compute and fill_cache instead" + ) + warnings.warn(msg, DeprecationWarning) + if cache: + force_compute = False + fill_cache = True + else: + force_compute = True + fill_cache = False + + # Need to compute transpose if cache is invalid + need_compute = force_compute + if self._transpose is None: + need_compute = True + elif self._transpose_invalid: + need_compute = True + + # Need to apply transpose kernel if noop flag is applied + if noop_flag is not None: + need_compute = True + + # Return cached transpose if possible + if not need_compute: + return self._transpose + + # Allocate output if needed + data = self._data.contiguous().reshape(-1, self.size(-1)) + out = self._transpose + if out is None: + out = torch.empty( + (data.size(1), data.size(0)), + dtype=torch.uint8, + device=data.device, + ) + noop_flag = None + else: + self._transpose_invalid = False + + # Apply transpose kernel + fp8_dtype = self._fp8_dtype + if noop_flag is None: + tex.fp8_transpose_noalloc(data, out, fp8_dtype) + else: + noop_flag = noop_flag.to(dtype=torch.float32, device=data.device) + tex.fp8_transpose_noalloc_noop(data, out, noop_flag, fp8_dtype) + + # Fill cache if needed + if fill_cache: + self._transpose = out + self._transpose_invalid = False + + return out + + @torch.no_grad() + def cast_transpose_( + self, + tensor: torch.Tensor, + noop_flag: Optional[torch.Tensor] = None, + ) -> None: + """Cast from tensor and populate transpose cache + + Tensor is reshaped as a 2D matrix. + + Parameters + ---------- + tensor: torch.Tensor + Tensor to copy from. Must have same dimensions as + destination tensor. + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid updating + destination tensor. + + """ + if self._transpose is None: + self._transpose = torch.empty( + (self.size(-1), self.numel() // self.size(-1)), + dtype=torch.uint8, + device=self.device, + ) + self.quantize_(tensor, noop_flag=noop_flag) + + @torch.no_grad() + def reset_fp8_meta_scale_inv(self) -> None: + """Replace FP8 meta tensor scale-inverse with cached value + + The FP8 meta tensor scale_inv entry corresponding to this + tensor is replaced with the scale_inv value used to construct + the tensor. + + """ + assert self._fp8_meta is not None, "FP8 meta tensors not found." + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=self._fp8_meta_forward, + ) + self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index].copy_(self._scale_inv[0]) + + def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: + """Create `Float8Tensor` with given nominal dtype + + The new tensor has the same underlying FP8 data. + + """ + return Float8Tensor.make_like( + self, + data=self._data, + fp8_attrs=self._fp8_attrs, + dtype=dtype, + ) + + def _reset_caches(self) -> None: + """ + Set transpose cache as invalid. + Should be called after any in-place operation. + """ + self._transpose_invalid = True + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # Slice op + if func == aten.slice.Tensor: + tensor = args[0] + data = tensor._data + data_slice = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return Float8Tensor.make_like(tensor, data=data_slice) + + # View op + if func == aten.view.default: + tensor = args[0] + data = tensor._data + data_view = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return Float8Tensor.make_like(tensor, data=data_view) + + # Default case + return super().__torch_dispatch__(func, types, args, kwargs) + + @classmethod + def _make_in_reduce_ex( + cls, + data: torch.Tensor, + fp8_dtype: TE_DType, + fp8_scale_inv: torch.Tensor, + dtype: torch.dtype, + ) -> Float8Tensor: + """Build Float8Tensor, for use in __reduce__ + + __reduce_ex__ assumes object constructor has positional + arguments. + + """ + return Float8Tensor( + data=data, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + dtype=dtype, + ) + + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling to remove references to FP8 metadata objects""" + return ( + Float8Tensor._make_in_reduce_ex, + (self._data, self._fp8_dtype, self._scale_inv, self.dtype), + ) + + def _get_data(self) -> Float8Tensor: + """Get tensor data property""" + return super().data + + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + + Cast tensor to FP8 and store in FP8 buffer. + + """ + with torch.no_grad(): + self.copy_(tensor) + + # Cast to FP8 when setting Float8Tensor.data + data = property(_get_data, _set_data) + + # Accessors for objects in self._fp8_attrs + # Note: We store FP8 attributes in a dictionary so we can share + # them between tensors with the same data, e.g. detached tensors. + # For convenience, we also expose them as property attributes. + _fp8_meta = property(**_make_fp8_attr_property_funcs("fp8_meta")) + _fp8_meta_forward = property(**_make_fp8_attr_property_funcs("fp8_meta_forward")) + _fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index")) + _fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype")) + _transpose = property(**_make_fp8_attr_property_funcs("transpose")) + _transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid")) + _scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv")) diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py new file mode 100644 index 0000000000..f890b0878a --- /dev/null +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -0,0 +1,172 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tensor with quantized data""" + +from __future__ import annotations +from typing import Optional, Tuple + +import torch +from torch.utils._pytree import tree_map + + +class _DequantizeFunc(torch.autograd.Function): + """Autograd function to convert quantized tensor to standard tensor""" + + @staticmethod + def forward( + _ctx: torch.autograd.function.FunctionCtx, # unused + tensor: QuantizedTensor, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + return tensor.dequantize(dtype=dtype) + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + return grad, None + + +class _IdentityFunc(torch.autograd.Function): + """Autograd function to create quantized tensor with same data""" + + @staticmethod + def forward( + _ctx: torch.autograd.function.FunctionCtx, # unused + tensor: QuantizedTensor, + ) -> QuantizedTensor: + return tensor.detach() + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> torch.Tensor: + return grad + + +class QuantizedTensor(torch.Tensor): + """Abstract base class for tensor with quantized data + + This is a proxy class with the interface of a standard PyTorch + tensor, but with data that has been encoded with some quantization + scheme. Derived classes should implement the quantization scheme + by overriding the `quantize_` and `dequantize` functions. + + """ + + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """Convert quantized data to standard PyTorch tensor""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement dequantize function" + ) + + def quantize_(self, tensor: torch.Tensor) -> QuantizedTensor: + """Update quantized data in-place""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement quantize_ function" + ) + + def detach(self) -> QuantizedTensor: + """Create new quantized tensor with same data + + Output tensor must be detached from the current autograd + graph. + + """ + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement detach function" + ) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})" + + def float(self) -> torch.Tensor: + return _DequantizeFunc.apply(self, torch.float32) + + def bfloat16(self) -> torch.Tensor: + return _DequantizeFunc.apply(self, torch.bfloat16) + + def half(self) -> torch.Tensor: + return _DequantizeFunc.apply(self, torch.float16) + + def cpu(self) -> torch.Tensor: + return _DequantizeFunc.apply(self).cpu() + + def expand_as(self, other: torch.Tensor) -> torch.Tensor: + if other is self: + # Note: expand_as is hackily used to create dummy autograd nodes + # and access the backward graph (see + # https://github.com/pytorch/pytorch/blob/238fb660851268f44ff88127887041fea352fe48/torch/nn/parallel/distributed.py#L1026). + # We hackily add a dummy function to handle this case. + return _IdentityFunc.apply(self) + return super().expand_as(other) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # Detach op + if func == torch.ops.aten.detach.default: + return args[0].detach() + + # In-place copy op + if func == torch.ops.aten.copy_.default: + dst = args[0] + src = args[1] + if isinstance(dst, QuantizedTensor): + dst.quantize_(src) + else: + if isinstance(src, QuantizedTensor): + src = src.dequantize() + dst.copy_(src) + return None + + # View op + if func == torch.ops.aten.view.default: + raise NotImplementedError("{cls.__name__} class does not support tensor views") + + def maybe_unwrap(arg): + if isinstance(arg, QuantizedTensor): + return arg.dequantize(dtype=arg.dtype) + return arg + + def maybe_update_inplace(arg, new_arg, schema_arg): + if ( + isinstance(arg, QuantizedTensor) + and isinstance(new_arg, torch.Tensor) + and hasattr(schema_arg, "alias_info") + and hasattr(schema_arg.alias_info, "is_write") + and schema_arg.alias_info.is_write + ): + arg.quantize_(new_arg) + + # In-place op: dequantize, perform op, and quantize + if func._schema.is_mutable: + new_args = tree_map(maybe_unwrap, args) + new_kwargs = tree_map(maybe_unwrap, kwargs) + schema_args = func._schema.arguments + args_len = len(args) + super().__torch_dispatch__(func, types, new_args, new_kwargs) + for arg, new_arg, schema_arg in zip(args, new_args, schema_args): + maybe_update_inplace(arg, new_arg, schema_arg) + for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]): + assert kwarg == new_kwarg == schema_arg.name, "name of the kw argument should match" + maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg) + return None + + # Default op: dequantize and perform op + args = tree_map(maybe_unwrap, args) + if kwargs is not None: + kwargs = tree_map(maybe_unwrap, kwargs) + out = super().__torch_dispatch__(func, types, args, kwargs) + return out + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + # Do not force the QuantizedTensor type on the returned tensor + return torch._C._disabled_torch_function_impl(func, types, args, kwargs) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 5e3fa05f52..d5145455b8 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Utility functions for Transformer Engine modules""" +from __future__ import annotations import functools import math from typing import Any, Callable, Optional, Tuple @@ -251,3 +252,52 @@ def get_cudnn_version() -> Tuple[int, int, int]: major, encoded_version = divmod(encoded_version, major_version_magnitude) minor, patch = divmod(encoded_version, 100) return (major, minor, patch) + + +def canonicalize_device(device: Optional[torch.device | str]) -> torch.device: + """Canonicalize PyTorch device + + If `None`, then returns the default CUDA device. + + """ + if device is None: + # Use default CUDA device + device = torch.get_default_device() + if device.type != "cuda": + device = torch.device("cuda", torch.cuda.current_device()) + elif not isinstance(device, torch.device): + device = torch.device(device) + if device.type == "cuda" and device.index is None: + device = torch.device("cuda", torch.cuda.current_device()) + return device + + +def canonicalize_dtype(dtype: Optional[torch.dtype]) -> torch.dtype: + """Canonicalize PyTorch datatype + + If `None`, then returns the default PyTorch datatype. + + """ + if dtype is None: + # Use default dtype + dtype = torch.get_default_dtype() + return dtype + + +def devices_match(device1: torch.device, device2: torch.device) -> bool: + """Whether two devices are the same""" + device1 = torch.device(device1) + device2 = torch.device(device2) + if device1.type != device2.type: + return False + if device1.type == "cuda": + index1 = device1.index + index2 = device2.index + if index1 == index2: + return True + if index1 is None: + index1 = torch.cuda.current_device() + if index2 is None: + index2 = torch.cuda.current_device() + return index1 == index2 + return device1 == device2 From e6e060303dd3b4614c6fe85b8fcf0063108b7fc8 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 11 Sep 2024 07:36:53 -0700 Subject: [PATCH 33/50] [PyTorch] Lower atol/rtol for F16 attention tests (#1157) * reduce atol/rtol for F16 tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * relax the tols for Ampere Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_fused_attn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index a1ebead04a..d110dece53 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -233,9 +233,9 @@ def test_dot_product_attention( """Test DotProductAttention module""" # Get configs - tols = dict(atol=5e-3, rtol=5e-3) + tols = dict(atol=1e-3, rtol=1e-3) if dtype == torch.bfloat16: - tols = dict(atol=2.5e-2, rtol=2.5e-2) + tols = dict(atol=1.5e-2, rtol=1.5e-2) config = model_configs[model] is_mla = config.head_dim_qk != config.head_dim_v if qkv_layout is None: @@ -1035,7 +1035,7 @@ def test_transformer_layer( # Get configs config = model_configs[model] - tols = dict(atol=5e-1, rtol=5e-2) + tols = dict(atol=5e-2, rtol=5e-2) workspace_opt = True # Test backend availability From c55007b85aa1a6563dfd2d45f5353ba5e3cfe54f Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 11 Sep 2024 12:13:21 -0700 Subject: [PATCH 34/50] Update CI users (#1180) Signed-off-by: Tim Moon --- .github/workflows/trigger-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index c66571b536..660bf916ed 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -34,6 +34,7 @@ jobs: || github.actor == 'phu0ngng' || github.actor == 'xrennvidia' || github.actor == 'yaox12' + || github.actor == 'huanghua1994' ) steps: - name: Check if comment is issued by authorized person From df699655e696f9f58c87576c44a58b393aebadb3 Mon Sep 17 00:00:00 2001 From: Michael Goldfarb Date: Mon, 16 Sep 2024 10:06:08 -0500 Subject: [PATCH 35/50] [JAX] Fix unit tests to work around cuDNN 9.4 regression of 0 length sequences (#1179) Modify unit tests to work around cuDNN 9.4 regression. Signed-off-by: Michael Goldfarb --- tests/jax/test_fused_attn.py | 63 ++++++++++++++++++++++++++++++------ 1 file changed, 53 insertions(+), 10 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 390a3e2c4e..2709eeab30 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -29,7 +29,10 @@ get_qkv_format, ) from transformer_engine.jax.cpp_extensions import FusedAttnHelper -from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend +from transformer_engine.transformer_engine_jax import ( + NVTE_Fused_Attn_Backend, + get_cudnn_version, +) from utils import assert_allclose @@ -230,7 +233,14 @@ def customcall_fused_dpa( kwargs.pop("max_segments_per_seq") return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype) return fused_attn_thd( - qkv_args, bias, seqlens_q, seqlens_kv, offsets_q, offsets_kv, dropout_rng, **kwargs + qkv_args, + bias, + seqlens_q, + seqlens_kv, + offsets_q, + offsets_kv, + dropout_rng, + **kwargs, ).astype(query.dtype) @@ -265,6 +275,15 @@ class FusedAttnRunner: qkv_layout: QKVLayout bias_shape: BiasShape + # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue + # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases. + def _get_max_segments_per_sequence(self): + if 90400 <= get_cudnn_version() < 90500: + return self.num_segments_per_seq + else: + # +1 for testing runtime_segments < max_segments + return self.num_segments_per_seq + 1 + def _check_configs(self): # TODO(rewang): probably adds this in is_fused_attn_available if get_qkv_format(self.qkv_layout) == QKVFormat.THD and not self.attn_mask_type in [ @@ -299,7 +318,10 @@ def _check_configs(self): self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS and self.bias_shape != BiasShape.BIAS_1HSS ): - if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: + if self.attn_mask_type not in [ + AttnMaskType.NO_MASK, + AttnMaskType.CAUSAL_MASK, + ]: pytest.skip( "B1SS, BHSS and 11SS bias shapes are only supported for " "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK." @@ -316,7 +338,12 @@ def _setup_inputs(self): q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5) q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim) - k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim) + k_shape = v_shape = ( + self.batch_size, + self.max_seqlen_kv, + self.num_heads_kv, + self.head_dim, + ) if self.attn_bias_type == AttnBiasType.NO_BIAS: bias_shape = None @@ -325,7 +352,12 @@ def _setup_inputs(self): elif self.bias_shape == BiasShape.BIAS_B1SS: bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv) elif self.bias_shape == BiasShape.BIAS_BHSS: - bias_shape = (self.batch_size, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv) + bias_shape = ( + self.batch_size, + self.num_heads_q, + self.max_seqlen_q, + self.max_seqlen_kv, + ) elif self.bias_shape == BiasShape.BIAS_11SS: bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv) else: @@ -405,7 +437,10 @@ def generate_random_segment_ids( self.segment_pad_kv = self.segment_pad_q else: self.token_kv, self.segment_pad_kv = generate_random_segment_ids( - self.batch_size, self.max_seqlen_kv, self.num_segments_per_seq, seed=2024 + self.batch_size, + self.max_seqlen_kv, + self.num_segments_per_seq, + seed=2024, ) self.pad_q = self.segment_pad_q self.pad_kv = self.segment_pad_kv @@ -464,8 +499,7 @@ def test_forward(self): "dropout_probability": self.dropout_prob, "is_training": self.is_training, "qkv_layout": self.qkv_layout, - # +1 for testing runtime_segments < max_segments - "max_segments_per_seq": self.num_segments_per_seq + 1, + "max_segments_per_seq": self._get_max_segments_per_sequence(), } # Convert the outputs to float32 for the elementwise comparison @@ -522,7 +556,7 @@ def grad_func(func, *args, **kwargs): "dropout_probability": self.dropout_prob, "is_training": self.is_training, "qkv_layout": self.qkv_layout, - "max_segments_per_seq": self.num_segments_per_seq + 1, + "max_segments_per_seq": self._get_max_segments_per_sequence(), } # We can compute dBias only for the [1, h, s, s] layout @@ -635,7 +669,16 @@ def check_dqkv(primitive, reference, pad): pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"), pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"), pytest.param(4, 512, 128, 16, 16, 64, jnp.bfloat16, id="4-512-128-16-16-64-BF16-CROSS"), - pytest.param(2, 2048, 1024, 12, 12, 64, jnp.bfloat16, id="2-2048-1048-12-12-64-BF16-CROSS"), + pytest.param( + 2, + 2048, + 1024, + 12, + 12, + 64, + jnp.bfloat16, + id="2-2048-1048-12-12-64-BF16-CROSS", + ), pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"), pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"), ], From af5daa09e1bc6e29779ebfb1fc9bd56634218de5 Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Mon, 16 Sep 2024 10:08:55 -0700 Subject: [PATCH 36/50] Add dtensor support for TE optimizers (#1171) add dtensor support for te optimizers Signed-off-by: jasonwan Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/optimizers/multi_tensor_apply.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/transformer_engine/pytorch/optimizers/multi_tensor_apply.py b/transformer_engine/pytorch/optimizers/multi_tensor_apply.py index b8d6d1f263..191b57eab9 100644 --- a/transformer_engine/pytorch/optimizers/multi_tensor_apply.py +++ b/transformer_engine/pytorch/optimizers/multi_tensor_apply.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Multi-tensor apply entry.""" +from torch.distributed._tensor import DTensor class MultiTensorApply: # pylint: disable=too-few-public-methods @@ -12,6 +13,11 @@ def __init__(self, chunk_size): self.chunk_size = chunk_size def __call__(self, op, noop_flag_buffer, tensor_lists, *args): + for i, ts in enumerate(tensor_lists): + for j, t in enumerate(ts): + if isinstance(t, DTensor): + tensor_lists[i][j] = t._local_tensor + return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args) From d2d4cf9142d522562f6e05b2e5768619bd6a1356 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Mon, 16 Sep 2024 13:46:45 -0700 Subject: [PATCH 37/50] Update CI users (#1181) Update list of CI users Signed-off-by: Tim Moon --- .github/workflows/trigger-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index 660bf916ed..c25aa863ad 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -35,6 +35,7 @@ jobs: || github.actor == 'xrennvidia' || github.actor == 'yaox12' || github.actor == 'huanghua1994' + || github.actor == 'mgoldfarb-nvidia' ) steps: - name: Check if comment is issued by authorized person From 9101a78f124258ef76c598410a846645bff359c9 Mon Sep 17 00:00:00 2001 From: Michael Goldfarb Date: Tue, 17 Sep 2024 09:26:15 -0500 Subject: [PATCH 38/50] [JAX] Context Parallel Attention with All-Gather (#1106) Implementation of context parallel fused attention using all-gather. Signed-off-by: Michael Goldfarb --- tests/jax/distributed_test_base.py | 24 + tests/jax/test_distributed_fused_attn.py | 245 +++++- transformer_engine/jax/attention.py | 119 ++- .../jax/cpp_extensions/attention.py | 772 ++++++++++++------ .../jax/csrc/extensions/pybind.cpp | 5 +- transformer_engine/jax/sharding.py | 39 +- 6 files changed, 942 insertions(+), 262 deletions(-) diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 3a7fe33378..bbd54ecce5 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -4,6 +4,8 @@ import operator import re from functools import reduce +from itertools import product +import pytest import jax from jax.experimental.pjit import pjit, _UNSPECIFIED @@ -29,6 +31,28 @@ def generate_configs(): return configs +def generate_context_parallel_configs(): + configs = [] + + DP_sizes = (1, 2) + CP_sizes = (1, 2, 4, 8) + TP_sizes = (1, 2) + for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes): + ndev = cp * tp * dp + if is_devices_enough(ndev): + configs.append( + pytest.param( + ndev, + (dp, cp, tp), + ("dp", "cp", "tp"), + MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp"), + id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}", + ) + ) + + return configs + + COLL_AR_KEY = "all-reduce" COLL_AG_KEY = "all-gather" COLL_OTHER_KEY = "other" diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 15676dd270..61d68aacae 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -3,6 +3,7 @@ # See LICENSE for license information. import pytest +from functools import partial import jax import jax.numpy as jnp @@ -10,8 +11,13 @@ from flax.linen import dot_product_attention from jax import random from jax.sharding import Mesh, NamedSharding, PartitionSpec -from distributed_test_base import generate_configs, generate_collectives_count, compare_ops -from utils import make_causal_mask, make_self_mask +from distributed_test_base import ( + generate_configs, + generate_context_parallel_configs, + generate_collectives_count, + compare_ops, +) +from utils import make_causal_mask, make_self_mask, assert_tree_like_allclose, assert_allclose from transformer_engine.jax import fp8_autocast from transformer_engine.jax.attention import ( is_fused_attn_kernel_available, @@ -19,6 +25,10 @@ AttnBiasType, AttnMaskType, QKVLayout, + QKVFormat, + get_qkv_format, + reorder_causal_load_balancing, + inverse_reorder_causal_load_balancing, ) @@ -263,7 +273,8 @@ def target_func(q, kv, mask): scaling_factor=scaling_factor, dropout_probability=dropout_prob, is_training=is_training, - ) + ), + dtype=jnp.float32, ) def ref_func(query, kv, mask): @@ -284,7 +295,7 @@ def ref_func(query, kv, mask): dtype=jnp.float32, ) - return jnp.mean(output).astype(dtype) + return jnp.mean(output, dtype=jnp.float32) (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs( data_shape, mesh_resource, attn_mask_type, dtype @@ -310,3 +321,229 @@ def ref_func(query, kv, mask): in_shardings=(q_pspec, kv_pspec, mask_pspec), out_shardings=(None, (q_pspec, kv_pspec)), ) + + +class TestDistributedContexParallelSelfAttn: + + def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype): + batch, seqlen, heads, hidden = shape + qkey, kkey, vkey = random.split(random.PRNGKey(1124), 3) + q = random.normal(qkey, shape, dtype=dtype) + k = random.normal(kkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype) + v = random.normal(vkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype) + + mask = None + if attn_mask_type == AttnMaskType.CAUSAL_MASK: + mask = make_causal_mask(batch, seqlen) + + return q, k, v, mask + + def qkv_to_layout(self, q, k, v, qkv_layout): + qkv_args = () + match qkv_layout: + case QKVLayout.BSHD_BS2HD: + k, v = map(partial(jnp.expand_dims, axis=-3), [k, v]) + kv = jnp.concatenate((k, v), axis=-3) + qkv_args = (q, kv) + case QKVLayout.BSHD_BSHD_BSHD: + qkv_args = (q, k, v) + case _: + raise ValueError(f"Unsupported {qkv_layout=}") + return qkv_args + + @pytest.mark.parametrize( + "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() + ) + @pytest.mark.parametrize( + "data_shape", + [ + pytest.param([2, 512, 12, 128], id="2-512-12-128"), + pytest.param([4, 1024, 16, 64], id="4-1024-16-64"), + ], + ) + @pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16]) + @pytest.mark.parametrize( + "attn_mask_type", + [ + pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"), + pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"), + ], + ) + @pytest.mark.parametrize("dtype", [jnp.bfloat16]) + @pytest.mark.parametrize( + "qkv_layout", + [ + pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"), + pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"), + ], + ) + @pytest.mark.parametrize( + "load_balanced", [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")] + ) + def test_contex_parallel_self_attn( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + kv_groups, + attn_mask_type, + dtype, + qkv_layout, + load_balanced, + ): + attn_bias_type = AttnBiasType.NO_BIAS + dropout_prob = 0.0 + is_training = True + scaling_factor = 1.0 + dp_size, cp_size, tp_size = mesh_shape + qkv_format = get_qkv_format(qkv_layout) + + _, seqlen, num_head, hidden = data_shape + num_kv_heads = num_head // kv_groups + + # make sure the mesh evently divides cp and tp axis + if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0: + pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}") + + def target_func(q, k, v, mask): + return jnp.mean( + fused_attn( + self.qkv_to_layout(q, k, v, qkv_layout), + bias=None, + mask=mask, + seed=None, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_prob, + is_training=is_training, + context_parallel_causal_load_balanced=load_balanced, + ), + ).astype(dtype) + + def ref_func(q, k, v, mask, kv_groups): + q = jnp.squeeze(q) + k = jnp.squeeze(jnp.repeat(k, kv_groups, axis=2)) + v = jnp.squeeze(jnp.repeat(v, kv_groups, axis=2)) + output = dot_product_attention( + q, + k, + v, + bias=None, + mask=mask, + deterministic=is_training, + dropout_rate=dropout_prob, + dropout_rng=None, + dtype=jnp.float32, + ) + return jnp.mean(output).astype(dtype) + + q, k, v, mask = self.generate_inputs(data_shape, kv_groups, attn_mask_type, dtype) + + # Single GPU (reference) + ref_func_jit = jax.jit(jax.value_and_grad(ref_func, argnums=[0, 1, 2]), static_argnums=[4]) + ref_fwd, ref_grads = ref_func_jit(q, k, v, mask, kv_groups) + + # Multi GPU (function under test) + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + with mesh, fp8_autocast(mesh_resource=mesh_resource): + qkv_ps = PartitionSpec( + mesh_resource.dp_resource, + mesh_resource.cp_resource, + mesh_resource.tp_resource, + None, + ) + qkv_sharding = NamedSharding(mesh, qkv_ps) + + mask_ps = PartitionSpec( + mesh_resource.dp_resource, None, mesh_resource.cp_resource, None + ) + mask_sharding = NamedSharding(mesh, mask_ps) + + reorder = partial( + reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format + ) + inverse_reorder = partial( + inverse_reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format + ) + + if load_balanced: + q, k, v = jax.tree.map(reorder, (q, k, v)) + + q_, k_, v_ = map(partial(jax.device_put, device=qkv_sharding), [q, k, v]) + mask_ = jax.device_put(mask, device=mask_sharding) + + target_func_jit = jax.jit( + jax.value_and_grad(target_func, argnums=[0, 1, 2]), + in_shardings=[qkv_sharding, qkv_sharding, qkv_sharding, mask_sharding], + out_shardings=(None, (qkv_sharding, qkv_sharding, qkv_sharding)), + ) + + target_fwd, target_grads = target_func_jit(q_, k_, v_, mask_) + + if load_balanced: + target_dq, target_dk, target_dv = jax.tree.map(inverse_reorder, target_grads[0:3]) + target_grads = (target_dq, target_dk, target_dv, *target_grads[3:]) + + def _print_diffs(target, ref): + print("min: ", jnp.min(target), jnp.min(ref)) + print("max: ", jnp.max(target), jnp.max(ref)) + print("mean: ", jnp.mean(target), jnp.mean(ref)) + print("median: ", jnp.median(target), jnp.median(ref)) + print("std: ", jnp.std(target), jnp.std(ref)) + print("var: ", jnp.var(target), jnp.var(ref)) + print("max diff: ", jnp.max(jnp.abs(target - ref))) + + has_diffs = False + + try: + assert_allclose(target_fwd, ref_fwd, dtype=dtype) + except AssertionError as e: + has_diffs = True + print(f"target_fwd v. ref_fwd") + _print_diffs(target_fwd, ref_fwd) + + for i in range(len(target_grads)): + if ref_grads[i] is None or target_grads[i] is None: + # expect both none if one is + assert target_grads[i] is None and ref_grads[i] is None + else: + try: + assert_allclose(target_grads[i], ref_grads[i]) + except AssertionError as e: + has_diffs = True + print(f"target_grads[{i}] v. ref_grads[{i}]") + _print_diffs(target_grads[i], ref_grads[i]) + + assert has_diffs == False, "has_diffs != False" + + +class TestReorderCausalLoadBalancing: + @pytest.mark.parametrize("cp_size", [2, 4, 8]) + @pytest.mark.parametrize( + "shape", + [ + pytest.param([1, 16, 1, 1], id="1-16-1-1"), + pytest.param([4, 32, 12, 32], id="4-32-12-32"), + pytest.param([3, 32, 8, 64], id="3-32-8-64"), + ], + ) + @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD]) + def test(self, cp_size, shape, qkv_format): + tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16) + if qkv_format == QKVFormat.SBHD: + tensor = tensor.swapaxes(0, 1) + + ref = tensor.copy() + + reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2]) + inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2]) + + reordered = reorder(tensor, cp_size, qkv_format) + inversed = inverse(reordered, cp_size, qkv_format) + + assert jnp.array_equal(inversed, ref) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index dcd860c3a4..9b8279be25 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -43,6 +43,8 @@ class AttnMaskType(Enum): PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK PADDING_CAUSAL_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK + CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK + PADDING_CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK class QKVLayout(Enum): @@ -97,11 +99,21 @@ def canonicalize_attn_mask_type(attn_mask_type: str): return AttnMaskType.PADDING_MASK case "causal": return AttnMaskType.CAUSAL_MASK + case "causal_bottom_right" | "bottom_right_causal": + return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK case "padding_causal" | "causal_padding": return AttnMaskType.PADDING_CAUSAL_MASK + case ( + "padding_causal_bottom_right" + | "causal_padding_bottom_right" + | "bottom_right_causal_padding" + | "bottom_right_padding_causal" + ): + return AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK raise ValueError( - f"Unsupported {attn_mask_type=}, supported attn_mask_type=" - "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}" + f"Unsupported {attn_mask_type=}, supported attn_mask_type={{'no_mask', 'padding', 'causal'," + " 'padding_causal', 'causal_padding', 'causal_bottom_right'," + " 'padding_causal_bottom_right'}" ) @@ -155,6 +167,75 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout): return batch, q_max_seqlen, kv_max_seqlen +def _reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat, inverse: bool): + match tensor_format: + case QKVFormat.SBHD: + seq_dim = 0 + case QKVFormat.BSHD: + seq_dim = 1 + case _: + raise ValueError(f"{tensor_format=} is not supported for causal load balancing.") + + if cp_size == 1: + return tensor + + if cp_size % 2 != 0: + raise ValueError(f"{cp_size=} must be a multiple of 2.") + + # Need to ensure we have 2 pairs to swap for balancing between cp ranks + if tensor.shape[seq_dim] % (cp_size * 2) != 0: + raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}") + + # [B, S, H, D] -> [B, 2*cp_size, S/2*cp_size, D] + # [S, B, H, D] -> [2*cp_size, S/2*cp_size, B, H, D] + ori_tensor_shape = tensor.shape + tensor = tensor.reshape( + ( + *ori_tensor_shape[:seq_dim], + 2 * cp_size, + ori_tensor_shape[seq_dim] // (2 * cp_size), + *ori_tensor_shape[seq_dim + 1 :], + ) + ) + + parts = [] + if not inverse: + for cp_rank in range(cp_size): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + index = jnp.array([cp_rank, (2 * cp_size - cp_rank - 1)]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) + else: + for cp_rank in range(cp_size // 2): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + base = 4 * cp_rank + index = jnp.array([base, base + 2]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) + for cp_rank in range(cp_size // 2): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + base = 2 * cp_size - 1 - 4 * cp_rank + index = jnp.array([base, base - 2]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) + + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] + combined = jnp.stack(parts, axis=seq_dim) + + return combined.reshape(ori_tensor_shape) + + +def reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat): + """Reorders a tensor for load balancing the compute of causal attention.""" + return _reorder_causal_load_balancing(tensor, cp_size, tensor_format, False) + + +def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat): + """Inverse operation of `reorder_causal_load_balancing`.""" + return _reorder_causal_load_balancing(tensor, cp_size, tensor_format, True) + + def fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], @@ -166,6 +247,8 @@ def fused_attn( scaling_factor: float, dropout_probability: float, is_training: bool, + context_parallel_causal_load_balanced: bool = False, + context_parallel_axis: str = "", ): """ Perform non-THD (non-packed) cuDNN fused attention. @@ -192,6 +275,9 @@ def fused_attn( scaling_factor (float): Scaling factor for the attention scores. dropout_probability (float): Dropout probability to apply during attention. is_training (bool): Flag indicating whether the model is in training mode. + context_parallel_causal_load_balanced (bool): + Indicates the sequences are ordered for causal mask load balancing when running context parallelism. + context_parallel_axis (str): The name of the context parallel axis. Returns: (jnp.ndarray): The output tensor from the fused attention. """ @@ -213,7 +299,11 @@ def fused_attn( ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" # convert the mask to seqlens, mask doesn't support ragged offsets - if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: + if attn_mask_type in [ + AttnMaskType.NO_MASK, + AttnMaskType.CAUSAL_MASK, + AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, + ]: batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout) q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32) kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32) @@ -242,6 +332,8 @@ def fused_attn( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=1, + context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, + context_parallel_axis=context_parallel_axis, ) return output @@ -262,6 +354,8 @@ def fused_attn_thd( dropout_probability: float, is_training: bool, max_segments_per_seq: int = 1, + context_parallel_causal_load_balanced: bool = False, + context_parallel_axis: str = "", ): """ (Experimental) Perform THD (packed) cuDNN fused attention. @@ -300,6 +394,9 @@ def fused_attn_thd( Indicating the maximum number of segments inside a sequence. This parameter is to constrain the limit usage and need to be static during the e2e training. The XLA compile time and memory consumption is proportional to `max_segments_per_seq`. + context_parallel_causal_load_balanced (bool): + Indicates the sequences are ordered for causal mask load balancing when running context parallelism. + context_parallel_axis (str): The name of the context parallel axis. Returns: (jnp.ndarray): The output tensor from the fused attention. @@ -354,12 +451,14 @@ def fused_attn_thd( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, + context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, + context_parallel_axis=context_parallel_axis, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15)) def _fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], @@ -375,6 +474,8 @@ def _fused_attn( dropout_probability: float, is_training: bool, max_segments_per_seq: int, + context_parallel_causal_load_balanced: bool, + context_parallel_axis: str, ): output, _ = _fused_attn_fwd_rule( qkv, @@ -391,6 +492,8 @@ def _fused_attn( dropout_probability, is_training, max_segments_per_seq, + context_parallel_causal_load_balanced, + context_parallel_axis, ) return output @@ -410,6 +513,8 @@ def _fused_attn_fwd_rule( dropout_probability, is_training, max_segments_per_seq, + context_parallel_causal_load_balanced, + context_parallel_axis, ): output, softmax_aux, rng_state = tex.fused_attn_fwd( qkv, @@ -426,6 +531,8 @@ def _fused_attn_fwd_rule( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, + context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, + context_parallel_axis=context_parallel_axis, ) output = checkpoint_name(output, "context") softmax_aux = checkpoint_name(softmax_aux, "context") @@ -451,6 +558,8 @@ def _fused_attn_bwd_rule( dropout_probability, is_training, max_segments_per_seq, + context_parallel_causal_load_balanced, + context_parallel_axis, ctx, dz, ): @@ -483,6 +592,8 @@ def _fused_attn_bwd_rule( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, + context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, + context_parallel_axis=context_parallel_axis, ) if attn_bias_type == AttnBiasType.NO_BIAS: grad_bias = None diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 0cbf847dcd..d5b901c107 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -9,8 +9,9 @@ from typing import Optional, Tuple import warnings +import jax import jax.numpy as jnp -from jax import dtypes +from jax import dtypes, lax from jax.interpreters import mlir from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding @@ -34,7 +35,11 @@ get_cudnn_version, ) from ..sharding import ( + global_mesh_resource, + lax_paral_op, all_reduce_sum_along_dp_fsdp, + get_mesh_axis_size, + get_mesh_axis_rank, get_all_mesh_axes, num_of_devices, ) @@ -47,6 +52,38 @@ ] +@partial( + jax.tree_util.register_dataclass, + data_fields=[], + meta_fields=[ + "attn_bias_type", + "attn_mask_type", + "qkv_layout", + "scaling_factor", + "dropout_probability", + "is_training", + "max_segments_per_seq", + "context_parallel_load_balanced", + "cp_axis", + ], +) +@dataclass(frozen=True) +class _FusedAttnConfig: + """ + Passes static configuration properties of fused attention. + """ + + attn_bias_type: NVTE_Bias_Type + attn_mask_type: NVTE_Mask_Type + qkv_layout: NVTE_QKV_Layout + scaling_factor: float + dropout_probability: float + is_training: bool + max_segments_per_seq: int + context_parallel_load_balanced: bool + cp_axis: str + + @dataclass(frozen=True) class FusedAttnHelper: """ @@ -178,7 +215,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): name = "te_fused_attn_forward" multiple_results = True - impl_static_args = (9, 10, 11, 12, 13, 14, 15) + impl_static_args = (9,) inner_primitive = None outer_primitive = None @@ -194,13 +231,7 @@ def abstract( _k_seq_offsets, seed_aval, *, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, + config: _FusedAttnConfig, ): """ Fused attention fwd abstract @@ -213,7 +244,7 @@ def abstract( assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( - FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) + FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) ) output_shape = (*batch_shape, q_max_seqlen, attn_heads, head_dim) @@ -223,10 +254,10 @@ def abstract( backend = FusedAttnHelper( q_dtype, k_dtype, - qkv_layout, - attn_bias_type, - attn_mask_type, - dropout_probability, + config.qkv_layout, + config.attn_bias_type, + config.attn_mask_type, + config.dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, @@ -238,7 +269,7 @@ def abstract( softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen) softmax_dtype = q_dtype elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: - softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, max_segments_per_seq) + softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, config.max_segments_per_seq) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: raise ValueError(f"Unsupported {backend=}") @@ -252,7 +283,7 @@ def abstract( rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype) - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: bias_batch = bias_heads = 0 else: *bias_batch_shape, bias_heads, _, _ = bias_aval.shape @@ -270,14 +301,14 @@ def abstract( num_gqa_groups, bias_heads, head_dim, - scaling_factor, - dropout_probability, - attn_bias_type, - attn_mask_type, - qkv_layout, + config.scaling_factor, + config.dropout_probability, + config.attn_bias_type, + config.attn_mask_type, + config.qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), - is_training, - max_segments_per_seq, + config.is_training, + config.max_segments_per_seq, ) wkspace_aval = q_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) @@ -308,28 +339,12 @@ def lowering( k_seq_offsets, seed, *, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, + config: _FusedAttnConfig, ): """ Fused attention fwd lowering rules """ - operands = [ - q, - k, - v, - bias, - q_cu_seqlen, - kv_cu_seqlen, - q_seq_offsets, - k_seq_offsets, - seed, - ] + operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, q_seq_offsets, k_seq_offsets, seed] operand_shapes = map(lambda x: x.type.shape, operands) out_types = [ ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) @@ -340,12 +355,12 @@ def lowering( q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( - FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) + FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) ) input_batch = reduce(operator.mul, batch_shape) - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: bias_batch = bias_heads = 0 else: *bias_batch_shape, bias_heads, _, _ = bias_aval.shape @@ -362,16 +377,16 @@ def lowering( num_gqa_groups, bias_heads, head_dim, - max_segments_per_seq, + config.max_segments_per_seq, wkspace_aval.size, - scaling_factor, - dropout_probability, - attn_bias_type, - attn_mask_type, - qkv_layout, + config.scaling_factor, + config.dropout_probability, + config.attn_bias_type, + config.attn_mask_type, + config.qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - is_training, + config.is_training, not FusedAttnHelper.is_non_deterministic_allowed(), ) @@ -390,17 +405,11 @@ def impl( q_seq_offsets, k_seq_offsets, seed, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, + config: _FusedAttnConfig, ): assert FusedAttnFwdPrimitive.inner_primitive is not None - if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD: + if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD: def _fix_len_take(x, condition, fill_value=-1): x_shape = x.shape @@ -418,7 +427,7 @@ def convert_to_2d(offsets, batch, max_seqlen): ) return offsets_2d - match qkv_layout: + match config.qkv_layout: case NVTE_QKV_Layout.NVTE_T3HD: kv_max_seqlen = q_max_seqlen = q.shape[-4] kv_batch = q_batch = reduce(operator.mul, q.shape[:-4]) @@ -472,66 +481,27 @@ def convert_to_2d(offsets, batch, max_seqlen): q_seq_offsets, k_seq_offsets, seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training, - max_segments_per_seq=max_segments_per_seq, + config=config, ) return output, softmax_aux, rng_state @staticmethod - def batcher( - batched_args, - batch_dims, - *, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, - ): + def batcher(batched_args, batch_dims, *, config): check_valid_batch_dims(batch_dims) assert FusedAttnFwdPrimitive.outer_primitive is not None q_bdim, *_, seed_bdim = batch_dims out_bdims = q_bdim, q_bdim, seed_bdim return ( - FusedAttnFwdPrimitive.outer_primitive.bind( - *batched_args, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training, - max_segments_per_seq=max_segments_per_seq, - ), + FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config), out_bdims, ) @staticmethod - def infer_sharding_from_operands( - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, - mesh, - arg_infos, - result_infos, - ): - del attn_bias_type, attn_mask_type, scaling_factor - del dropout_probability, is_training, max_segments_per_seq, result_infos + def infer_sharding_from_operands(config, mesh, arg_infos, result_infos): + del result_infos q_spec = get_padded_spec(arg_infos[0]) - k_spec = get_padded_spec(arg_infos[1]) - match qkv_layout: + match config.qkv_layout: case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD: # q_spec = (...batch, q_seqlen, head, hidden) out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:])) @@ -543,33 +513,22 @@ def infer_sharding_from_operands( # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden) out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-4]) + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) ) case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD: # q_spec = (...batch, q_seqlen, head, hidden) # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden) out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3]) + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) ) case _: - raise ValueError(f"Unsupported {qkv_layout=}") + raise ValueError(f"Unsupported {config.qkv_layout=}") rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) return (out_sharding, softmax_aux_sharding, rng_state_sharding) @staticmethod - def partition( - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, - mesh, - arg_infos, - result_infos, - ): + def partition(config, mesh, arg_infos, result_infos): out_sharding = result_infos[0].sharding softmax_aux_sharding = result_infos[1].sharding rng_state_sharding = seed_sharding = NamedSharding( @@ -577,16 +536,7 @@ def partition( ) arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) - impl = partial( - FusedAttnFwdPrimitive.impl, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training, - max_segments_per_seq=max_segments_per_seq, - ) + impl = partial(FusedAttnFwdPrimitive.impl, config=config) return mesh, impl, out_shardings, arg_shardings @@ -600,7 +550,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): name = "te_fused_attn_backward" multiple_results = True - impl_static_args = (12, 13, 14, 15, 16, 17, 18) + impl_static_args = (12,) inner_primitive = None outer_primitive = None @@ -619,13 +569,7 @@ def abstract( _q_seq_offsets, _k_seq_offsets, *, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, + config, ): """ Fused attention bwd abstract @@ -641,10 +585,10 @@ def abstract( assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( - FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) + FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) ) - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: bias_batch = bias_heads = 0 else: *bias_batch_shape, bias_heads, _, _ = bias_aval.shape @@ -662,15 +606,15 @@ def abstract( num_gqa_groups, bias_heads, head_dim, - scaling_factor, - dropout_probability, - attn_bias_type, - attn_mask_type, - qkv_layout, + config.scaling_factor, + config.dropout_probability, + config.attn_bias_type, + config.attn_mask_type, + config.qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), - is_training, + config.is_training, deterministic, - max_segments_per_seq, + config.max_segments_per_seq, ) dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) @@ -707,13 +651,7 @@ def lowering( q_seq_offsets, k_seq_offsets, *, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, + config, ): """ Fused attention bwd lowering rules @@ -743,12 +681,12 @@ def lowering( q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( - FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) + FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) ) input_batch = reduce(operator.mul, batch_shape) - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: bias_batch = bias_heads = 0 else: *bias_batch_shape, bias_heads, _, _ = bias_aval.shape @@ -765,16 +703,16 @@ def lowering( num_gqa_groups, bias_heads, head_dim, - max_segments_per_seq, + config.max_segments_per_seq, wkspace_aval.size, - scaling_factor, - dropout_probability, - attn_bias_type, - attn_mask_type, - qkv_layout, + config.scaling_factor, + config.dropout_probability, + config.attn_bias_type, + config.attn_mask_type, + config.qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - is_training, + config.is_training, not FusedAttnHelper.is_non_deterministic_allowed(), ) @@ -796,17 +734,11 @@ def impl( kv_seqlen, q_seq_offsets, k_seq_offsets, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, + config, ): assert FusedAttnBwdPrimitive.inner_primitive is not None - if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD: + if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD: def _fix_len_take(x, condition, fill_value=-1): x_shape = x.shape @@ -825,7 +757,7 @@ def convert_to_2d(offsets, batch, max_seqlen): ) return offsets_2d - match qkv_layout: + match config.qkv_layout: case NVTE_QKV_Layout.NVTE_T3HD: kv_max_seqlen = q_max_seqlen = q.shape[-4] kv_batch = q_batch = reduce(operator.mul, q.shape[:-4]) @@ -882,63 +814,25 @@ def convert_to_2d(offsets, batch, max_seqlen): kv_cu_seqlen, q_seq_offsets, k_seq_offsets, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training, - max_segments_per_seq=max_segments_per_seq, + config=config, ) return dq, dk, dv, dbias @staticmethod - def batcher( - batched_args, - batch_dims, - *, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, - ): + def batcher(batched_args, batch_dims, *, config): check_valid_batch_dims(batch_dims) assert FusedAttnBwdPrimitive.outer_primitive is not None q_bdim, k_bdim, v_bdim, *_ = batch_dims out_bdims = q_bdim, k_bdim, v_bdim, q_bdim return ( - FusedAttnBwdPrimitive.outer_primitive.bind( - *batched_args, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training, - max_segments_per_seq=max_segments_per_seq, - ), + FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config), out_bdims, ) @staticmethod - def infer_sharding_from_operands( - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, - mesh, - arg_infos, - result_infos, - ): - del attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, max_segments_per_seq - del dropout_probability, is_training, result_infos + def infer_sharding_from_operands(config, mesh, arg_infos, result_infos): + del config, result_infos q_spec = get_padded_spec(arg_infos[0]) k_spec = get_padded_spec(arg_infos[1]) v_spec = get_padded_spec(arg_infos[2]) @@ -950,18 +844,7 @@ def infer_sharding_from_operands( return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) @staticmethod - def partition( - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, - mesh, - arg_infos, - result_infos, - ): + def partition(config, mesh, arg_infos, result_infos): del result_infos q_spec = get_padded_spec(arg_infos[0]) k_spec = get_padded_spec(arg_infos[1]) @@ -1001,16 +884,10 @@ def sharded_impl( kv_cu_seqlen, q_seq_offsets, k_seq_offsets, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training, - max_segments_per_seq=max_segments_per_seq, + config=config, ) global_dbias = local_dbias - if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh) return local_dq, local_dk, local_dv, global_dbias @@ -1020,6 +897,378 @@ def sharded_impl( register_primitive(FusedAttnBwdPrimitive) +@dataclass(frozen=True) +class _FusedAttnCPWithAllGatherHelper: + """Helper class to assist with running the all-gather strategy for CP attention.""" + + mesh: jax.sharding.Mesh + config: _FusedAttnConfig + + def check_supported(self): + """Checks if the context parallel implementation is supported by the given arguments.""" + header = "Context parallel fused attention" + + allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD] + assert self.config.qkv_layout in allowed_layouts, ( + f"{header} only supports layouts: {','.join([str(x) for x in allowed_layouts])} got:" + f" {self.config.qkv_layout}" + ) + + assert ( + self.config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS + ), f"{header} does not support bias got: {self.config.attn_bias_type}" + + allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK] + assert self.config.attn_mask_type in allowed_masks, ( + f"{header} only supports masking types: " + f" {','.join([str(x) for x in allowed_masks])} got: {self.config.attn_mask_type}" + ) + + assert self.config.max_segments_per_seq == 1, ( + f"{header} only supports max_segments_per_seq == 1 got:" + f" {self.config.max_segments_per_seq}" + ) + assert self.config.dropout_probability == 0.0, f"{header} does not support dropout" + + def get_adjusted_mask(self): + """Converts the mask for context parallelism.""" + if self.config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK: + return NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK + return self.config.attn_mask_type + + def all_gather_kv(self, k, v): + """Performs a all-gather of k and v over context parallel ranks.""" + + def ag(x): + return lax_paral_op( + x, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True + ) + + match self.config.qkv_layout: + case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: + return ag(k), v + case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: + return ag(k), ag(v) + + return k, v # fall through + + def reduce_scatter_dkv(self, dk, dv): + """Performs a reduce-scatter of dk and dv over context parallel ranks.""" + + def rs(x): + return lax_paral_op( + x, + lax.psum_scatter, + self.config.cp_axis, + mesh=self.mesh, + scatter_dimension=1, + tiled=True, + ) + + match self.config.qkv_layout: + case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: + return rs(dk), dv + case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: + return rs(dk), rs(dv) + + return dk, dv # fall through + + def kv_seqlens_for_rank(self, cp_rank, kv_max_seqlen, kv_seqlen_per_subrank): + """Returns sequence lengths of KV to use for each sub rank of the given cp_rank. + + Example: CP=4, MaxLen = 1024, Unbalanced + cp_rank 0: [128, 256] + cp_rank 1: [384, 512] + cp_rank 2: [640, 768] + cp_rank 3: [896, 1024] + + Example: CP=4, MaxLen = 1024, Balanced + cp_rank 0: [128, 1024] + cp_rank 1: [256, 896] + cp_rank 2: [384, 768] + cp_rank 3: [512, 640] + """ + if self.config.context_parallel_load_balanced: + kv_seq_this_rank = [ + (cp_rank + 1) * kv_seqlen_per_subrank, + kv_max_seqlen - cp_rank * kv_seqlen_per_subrank, + ] + else: + kv_seq_this_rank = [ + (cp_rank * 2 + 1) * kv_seqlen_per_subrank, + (cp_rank * 2 + 2) * kv_seqlen_per_subrank, + ] + return kv_seq_this_rank + + def slice_kv(self, k, v, slice_seq_len): + """Slices k and v tensors to a sequence length of slice_seq_len.""" + + def sliced(x): + return lax.dynamic_slice_in_dim(x, 0, slice_seq_len, axis=1) + + match self.config.qkv_layout: + case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: + return sliced(k), v + case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: + return sliced(k), sliced(v) + + return k, v # fall through + + def pad_kv(self, dk, dv, pad_seq_len): + """Pads dk and dv tensors to a sequence length of pad_seq_len.""" + + def pad(x, npad): + return jnp.pad(x, npad, "constant", constant_values=0.0) + + match self.config.qkv_layout: + case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: + npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0], [0, 0]] + return pad(dk, npad), dv + case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: + npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0]] + return pad(dk, npad), pad(dv, npad) + + return dk, dv # fall through + + +class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): + """ + Fused Attention Forward with Context Parallelism Primitive + + This context parallel implementation uses all-gather to collect KV inputs from context parallel ranks. + """ + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + # Call base implementation for non-context parallel mesh to avoid unecessary work. + is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + if not is_context_parallel: + return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos) + + helper = _FusedAttnCPWithAllGatherHelper(mesh, config) + helper.check_supported() + + out_sharding = result_infos[0].sharding + softmax_aux_sharding = result_infos[1].sharding + rng_state_sharding = seed_sharding = NamedSharding( + mesh, PartitionSpec(get_all_mesh_axes(), None) + ) + arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) + out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) + + def impl(q, k, v, bias, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets, seed): + + cp_size = get_mesh_axis_size(config.cp_axis, mesh) + cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) + + # cuDNN does not support right-aligned masking with dynamic sequence length padding. + # Therefore we must explicitly instantiate each CP rank slicing and use a runtime switch + # to select the appropriate computation. Each case generates a [..., SEQ/CP, ..] tensor + # meeting the expectation of the SPMD model. + # TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding + # mask/sequence length tensor to avoid this unrolled loop. + def _cross_attn(idx, q, k, v, bias, q_seqlen, kv_seqlen, seed): + kv_max_seqlen = k.shape[1] + kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2) + assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size" + + q_split = jnp.split(q, 2, axis=1) + + kv_seqlens_for_rank = helper.kv_seqlens_for_rank( + idx, kv_max_seqlen, kv_seqlen_per_subrank + ) + + results = [] + for sub_idx in range(2): + if config.attn_mask_type == NVTE_Mask_Type.NVTE_NO_MASK: + k_unmasked, v_unmasked = k, v # full kv used for unmasked + else: + k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx]) + + q_seqlen_for_step = q_seqlen / (cp_size * 2) + num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx] + kv_seqlen_for_step = (kv_seqlen / (cp_size * 2)) * num_kv_chunks + + output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl( + q_split[sub_idx], + k_unmasked, + v_unmasked, + bias, + q_seqlen_for_step, + kv_seqlen_for_step, + q_seq_offsets, + k_seq_offsets, + seed, + config=config, + ) + results.append((output, softmax_aux, rng_state)) + + output = jnp.concatenate((results[0][0], results[1][0]), axis=1) + softmax_aux = jnp.concatenate((results[0][1], results[1][1]), axis=2) + rng_state = results[1][2] # Use the final RNG state + + return output, softmax_aux, rng_state + + k_ag, v_ag = helper.all_gather_kv(k, v) + + functions = [ + partial(_cross_attn, idx, q, k_ag, v_ag, bias, q_seqlen, kv_seqlen, seed) + for idx in range(cp_size) + ] + + return lax.switch(cp_rank, functions) + + return mesh, impl, out_shardings, arg_shardings + + +register_primitive(FusedAttnCPWithAllGatherFwdPrimitive) + + +class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): + """ + Fused Attention Backward with Context Parallelism Primitive. + + This context parallel implementation uses all-gather to collect KV and dKV inputs from context parallel ranks. + The gradients are subsequently reduce-scattered back to each context parallel rank. + """ + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + # Call base implementation for non-context parallel mesh to avoid unecessary work. + is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + if not is_context_parallel: + return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos) + + # Ensure we can support this configuration with context parallelism. + helper = _FusedAttnCPWithAllGatherHelper(mesh, config) + helper.check_supported() + + del result_infos + q_spec = get_padded_spec(arg_infos[0]) + k_spec = get_padded_spec(arg_infos[1]) + v_spec = get_padded_spec(arg_infos[2]) + bias_spec = get_padded_spec(arg_infos[3]) + dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) + dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) + dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) + + def impl( + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + ): + cp_size = get_mesh_axis_size(config.cp_axis, mesh) + cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) + + # See comment in FusedAttnCPFwdPrimitive.partition for why we define this function. + def _cross_attn_bwd( + idx, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen + ): + kv_max_seqlen = k.shape[1] + kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2) + assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size" + + q_split = jnp.split(q, 2, axis=1) + output_split = jnp.split(output, 2, axis=1) + doutput_split = jnp.split(doutput, 2, axis=1) + softmax_aux_split = jnp.split(softmax_aux, 2, axis=2) + + kv_seqlens_for_rank = helper.kv_seqlens_for_rank( + idx, kv_max_seqlen, kv_seqlen_per_subrank + ) + + results = [] + for sub_idx in range(2): + if config.attn_mask_type == NVTE_Mask_Type.NVTE_NO_MASK: + k_unmasked, v_unmasked = k, v # full kv used for unmasked + else: + k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx]) + + q_seqlen_for_step = q_seqlen // (cp_size * 2) + num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx] + kv_seqlen_for_step = (kv_seqlen // (cp_size * 2)) * num_kv_chunks + + dq_local, dk_local, dv_local, dbias_local = FusedAttnBwdPrimitive.impl( + q_split[sub_idx], + k_unmasked, + v_unmasked, + bias, + softmax_aux_split[sub_idx], + rng_state, + output_split[sub_idx], + doutput_split[sub_idx], + q_seqlen_for_step, + kv_seqlen_for_step, + q_seq_offsets, + k_seq_offsets, + config=config, + ) + + # pad dk/dv to be unsliced shape so we can reduce scatter over all ranks. + if config.attn_mask_type != NVTE_Mask_Type.NVTE_NO_MASK: + pad_length = kv_max_seqlen - kv_seqlens_for_rank[sub_idx] + dk_local, dv_local = helper.pad_kv(dk_local, dv_local, pad_length) + + results.append((dq_local, dk_local, dv_local, dbias_local)) + + dq_local = jnp.concatenate((results[0][0], results[1][0]), axis=1) + dk_local_pad = results[0][1] + results[1][1] + dv_local_pad = results[0][2] + results[1][2] + return dq_local, dk_local_pad, dv_local_pad, results[1][3] + + k_ag, v_ag = helper.all_gather_kv(k, v) + + functions = [ + partial( + _cross_attn_bwd, + idx, + q, + k_ag, + v_ag, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + ) + for idx in range(cp_size) + ] + + dq, dk_local, dv_local, dbias = lax.switch(cp_rank, functions) + dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local) + + return dq, dk, dv, dbias + + return mesh, impl, out_shardings, arg_shardings + + +register_primitive(FusedAttnCPWithAllGatherBwdPrimitive) + + +def _maybe_context_parallel_axis(cp_axis: str): + if not cp_axis: + gmr = global_mesh_resource() + if gmr is not None: + cp_axis = gmr.cp_resource + else: + cp_axis = "" + return cp_axis + + def fused_attn_fwd( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], @@ -1035,6 +1284,8 @@ def fused_attn_fwd( dropout_probability: float, is_training: bool, max_segments_per_seq: int, + context_parallel_causal_load_balanced: bool = False, + context_parallel_axis: str = "", ) -> jnp.ndarray: """ Perform the forward pass of with cuDNN fused attention implementations. @@ -1063,6 +1314,9 @@ def fused_attn_fwd( scaling_factor (float): Scaling factor for the attention scores. dropout_probability (float): Dropout probability to apply during attention. is_training (bool): Flag indicating whether the model is in training mode. + context_parallel_causal_load_balanced (bool): + Indicates the sequences are ordered for causal mask load balancing when running context parallelism. + context_parallel_axis (str): The name of the context parallel axis. Returns: (jnp.ndarray): The output tensor from the fused attention. """ @@ -1094,14 +1348,7 @@ def fused_attn_fwd( assert bias is None bias = jnp.zeros(0, dtype=qkv[0].dtype) - return FusedAttnFwdPrimitive.outer_primitive.bind( - *qkv_for_primitive, - bias, - q_seqlen, - kv_seqlen, - q_seq_offsets if is_ragged else _not_used, - kv_seq_offsets if is_ragged else _not_used, - seed, + fused_config = _FusedAttnConfig( attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, qkv_layout=qkv_layout, @@ -1109,6 +1356,19 @@ def fused_attn_fwd( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, + context_parallel_load_balanced=context_parallel_causal_load_balanced, + cp_axis=_maybe_context_parallel_axis(context_parallel_axis), + ) + + return FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive.bind( + *qkv_for_primitive, + bias, + q_seqlen, + kv_seqlen, + q_seq_offsets if is_ragged else _not_used, + kv_seq_offsets if is_ragged else _not_used, + seed, + config=fused_config, ) @@ -1130,6 +1390,8 @@ def fused_attn_bwd( dropout_probability: float, is_training: bool, max_segments_per_seq: int, + context_parallel_causal_load_balanced: bool = False, + context_parallel_axis: str = "", ): """ Perform the backward pass of the cuDNN fused attention implementations. @@ -1159,7 +1421,9 @@ def fused_attn_bwd( scaling_factor (float): Scaling factor for the attention scores. dropout_probability (float): Dropout probability to apply during attention. is_training (bool): Flag indicating whether the model is in training mode. - + context_parallel_causal_load_balanced (bool): + Indicates the sequences are ordered for causal mask load balancing when running context parallelism. + context_parallel_axis (str): The name of the context parallel axis. Returns: Tuple[jnp.ndarray, ...], jnp.ndarray: - The first tuple contains the gradients with respect to the input `qkv` tensors in the @@ -1194,7 +1458,19 @@ def fused_attn_bwd( assert bias is None bias = jnp.zeros(0, dtype=qkv[0].dtype) - *qkv_grads, bias_grad = FusedAttnBwdPrimitive.outer_primitive.bind( + fused_config = _FusedAttnConfig( + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + max_segments_per_seq=max_segments_per_seq, + context_parallel_load_balanced=context_parallel_causal_load_balanced, + cp_axis=_maybe_context_parallel_axis(context_parallel_axis), + ) + + *qkv_grads, bias_grad = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive.bind( *qkv_for_primitive, bias, softmax_aux, @@ -1205,12 +1481,6 @@ def fused_attn_bwd( kv_seqlen, q_seq_offsets if is_ragged else _not_used, kv_seq_offsets if is_ragged else _not_used, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training, - max_segments_per_seq=max_segments_per_seq, + config=fused_config, ) return tuple(qkv_grads[: len(qkv)]), bias_grad diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 0a2172bb1b..14f449a76b 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -100,7 +100,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) - .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK); + .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) + .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) + .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", + NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 586e1a70c9..a14a8384cf 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -20,6 +20,7 @@ BATCH_AXES = "nvte_batch" SEQLEN_AXES = "nvte_seqlen" SEQLEN_TP_AXES = "nvte_seqlen_tp" +SEQLEN_CP_AXES = "nvte_seqlen_cp" HEAD_AXES = "nvte_head" HIDDEN_AXES = "nvte_hidden" HIDDEN_TP_AXES = "nvte_hidden_tp" @@ -65,6 +66,7 @@ def get_sharding_map_logic_axis_to_mesh_axis(): BATCH_AXES: batch_dim_rule, SEQLEN_AXES: None, SEQLEN_TP_AXES: gsr.tp_resource, + SEQLEN_CP_AXES: gsr.cp_resource, HEAD_AXES: gsr.tp_resource, HIDDEN_AXES: None, HIDDEN_TP_AXES: gsr.tp_resource, @@ -131,13 +133,15 @@ def get_padded_spec(spec, ndim): return spec + (None,) * (ndim - len(spec)) -def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str, mesh: jax.sharding.Mesh): +def lax_paral_op( + x: jnp.array, ops: Callable, mesh_resource: str, mesh: jax.sharding.Mesh, **kwargs +): """ A wrapper function to invoke lax.p* operations, like psum. """ if mesh_resource is not None: _, resource = _get_mesh_info(mesh_resource, mesh) - return ops(x, resource) + return ops(x, resource, **kwargs) return x @@ -148,6 +152,33 @@ def num_of_devices(): return len(jax.devices()) +def get_mesh_axis_size(axis, mesh=None): + """ + Get the axis size of the given mesh. + If the mesh is None, it would be replaced + by the global mesh. + """ + if mesh is None: + mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh + + if axis is None: + return 1 + + assert axis in mesh.shape, f"{axis} is not a axis of the given mesh {mesh.shape}" + return mesh.shape[axis] + + +def get_mesh_axis_rank(axis: str, mesh=None): + """ + Gets the local axis rank of the `axis` of the array. + If the mesh is None the rank is 0. + """ + if mesh is None: + return 0 + _, axis_name = _get_mesh_info(axis, mesh) + return jax.lax.axis_index(axis_name) + + @dataclass class MeshResource: """ @@ -168,12 +199,16 @@ class MeshResource: pp_resource : str, default = None The axis name in Mesh used to split model layers. along. If it is None, then pipeline parallelism is disabled. + cp_resource : str, default = None + The axis name in Mesh used to split sequence (context) dimensions along + in the attention. If it is None, then context parallelism is disabled. """ dp_resource: str = None tp_resource: str = None fsdp_resource: str = None pp_resource: str = None + cp_resource: str = None _GLOBAL_MESH_RESOURCE = MeshResource() From 44fd316f972a50be8105fdc41d2e5cd9efcf1a82 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 17 Sep 2024 09:58:55 -0500 Subject: [PATCH 39/50] [Common] Default CUDA_HOME to /usr/local/cuda when dynamically loading cuDNN and NVRTC (#1183) Defaulted CUDA_HOME/CUDA_PATH to /usr/local/cuda when attempting to dynamically load cuDNN and NVRTC Signed-off-by: Alp Dener --- transformer_engine/common/__init__.py | 36 +++++++++++++++------------ 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 46cfa9176a..4bcd1f8e27 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -47,20 +47,20 @@ def _get_sys_extension(): def _load_cudnn(): """Load CUDNN shared library.""" - + # Attempt to locate cuDNN in Python dist-packages lib_path = glob.glob( os.path.join( sysconfig.get_path("purelib"), f"nvidia/cudnn/lib/libcudnn.{_get_sys_extension()}.*[0-9]", ) ) - if lib_path: assert ( len(lib_path) == 1 ), f"Found {len(lib_path)} libcudnn.{_get_sys_extension()}.x in nvidia-cudnn-cuXX." return ctypes.CDLL(lib_path[0], mode=ctypes.RTLD_GLOBAL) + # Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH") if cudnn_home: libs = glob.glob(f"{cudnn_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True) @@ -68,13 +68,14 @@ def _load_cudnn(): if libs: return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) - cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") - if cuda_home: - libs = glob.glob(f"{cuda_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True) - libs.sort(reverse=True, key=os.path.basename) - if libs: - return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + # Attempt to locate cuDNN in CUDA_HOME, CUDA_PATH or /usr/local/cuda + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" + libs = glob.glob(f"{cuda_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True) + libs.sort(reverse=True, key=os.path.basename) + if libs: + return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) @@ -91,14 +92,15 @@ def _load_library(): def _load_nvrtc(): """Load NVRTC shared library.""" - cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") - if cuda_home: - libs = glob.glob(f"{cuda_home}/**/libnvrtc.{_get_sys_extension()}*", recursive=True) - libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs)) - libs.sort(reverse=True, key=os.path.basename) - if libs: - return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) - + # Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" + libs = glob.glob(f"{cuda_home}/**/libnvrtc.{_get_sys_extension()}*", recursive=True) + libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs)) + libs.sort(reverse=True, key=os.path.basename) + if libs: + return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + + # Attempt to locate NVRTC via ldconfig libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True) libs = libs.decode("utf-8").split("\n") sos = [] @@ -109,6 +111,8 @@ def _load_nvrtc(): sos.append(lib.split(">")[1].strip()) if sos: return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL) + + # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) From 4fb25ccfea9e2ad1227fe3a82712849a2dbd5131 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Tue, 17 Sep 2024 10:17:57 -0700 Subject: [PATCH 40/50] Changed VERSION to 1.11.0 Signed-off-by: Przemyslaw Tredak --- build_tools/VERSION.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index 12790c22b4..1cac385c6c 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -1.11.0.dev0 +1.11.0 From 63fd8ac4ab924c96c4474ae83f3a8dc0efcd8456 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 20 Sep 2024 13:38:00 -0700 Subject: [PATCH 41/50] Allow downloading of model weights automatically (#1172) * allow tutorial to download the model weights automatically Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * allow users to provide weight cache directory Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Sudhakar Singh Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/examples/te_llama/te_llama.py | 9 ++- ...tutorial_accelerate_hf_llama_with_te.ipynb | 65 +++++++++++++------ docs/examples/te_llama/utils.py | 56 ++++++++++++++-- 3 files changed, 103 insertions(+), 27 deletions(-) diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index cb384aa10c..4413bdfd00 100644 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -102,8 +102,11 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k Custom method adapted from `from_pretrained` method in HuggingFace Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 """ - vanilla_model = cls(config).to(kwargs["torch_dtype"]) - is_local = os.path.isdir(pretrained_model_name_or_path) + # Before loading the model, set the default dtype for torch + torch.set_default_dtype(kwargs["torch_dtype"]) + + # Load the vanilla model weights + vanilla_model = cls(config) subfolder = "" variant = None if os.path.isfile( @@ -133,7 +136,7 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k else: raise AssertionError("Only sharded PyTorch ckpt format supported at the moment") - resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + resolved_archive_file, _ = get_checkpoint_shard_files( pretrained_model_name_or_path, archive_file, ) diff --git a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb index 57c1bf6601..7013e85ec6 100644 --- a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb +++ b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb @@ -247,15 +247,24 @@ "restart_jupyter_notebook()\n", "\n", "\n", - "# Import necessary packages and methods\n", + "# Import necessary packages, methods and variables\n", "from utils import *\n", "\n", "\n", - "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", - "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n", - "# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n", - "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "# Provide Huggingface Access Token\n", + "hyperparams.hf_access_token = \"\"\n", + "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "hyperparams.weights_cache_dir = \"\"\n", + "\n", + "# For Llama 2, uncomment this line (also set by default)\n", + "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n", + "\n", + "# For Llama 3, uncomment this line\n", + "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n", + "\n", "hyperparams.mixed_precision = \"bf16\"\n", "\n", "\n", @@ -554,7 +563,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "bdb34b91", "metadata": {}, "outputs": [ @@ -573,15 +582,24 @@ "restart_jupyter_notebook()\n", "\n", "\n", - "# Import necessary packages and methods\n", + "# Import necessary packages, methods and variables\n", "from utils import *\n", "\n", "\n", - "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", - "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n", - "# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n", - "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "# Provide Huggingface Access Token\n", + "hyperparams.hf_access_token = \"\"\n", + "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "hyperparams.weights_cache_dir = \"\"\n", + "\n", + "# For Llama 2, uncomment this line (also set by default)\n", + "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n", + "\n", + "# For Llama 3, uncomment this line\n", + "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n", + "\n", "hyperparams.mixed_precision = \"bf16\"\n", "\n", "\n", @@ -653,15 +671,24 @@ "restart_jupyter_notebook()\n", "\n", "\n", - "# Import necessary packages and methods\n", + "# Import necessary packages, methods and variables\n", "from utils import *\n", "\n", "\n", - "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", - "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n", - "# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n", - "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "# Provide Huggingface Access Token\n", + "hyperparams.hf_access_token = \"\"\n", + "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "hyperparams.weights_cache_dir = \"\"\n", + "\n", + "# For Llama 2, uncomment this line (also set by default)\n", + "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n", + "\n", + "# For Llama 3, uncomment this line\n", + "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n", + "\n", "hyperparams.mixed_precision = \"fp8\"\n", "\n", "\n", diff --git a/docs/examples/te_llama/utils.py b/docs/examples/te_llama/utils.py index b6b3683d4c..1aebe13afb 100644 --- a/docs/examples/te_llama/utils.py +++ b/docs/examples/te_llama/utils.py @@ -25,7 +25,10 @@ class HyperParameters: def __init__(self): self.mixed_precision = "bf16" - # self.model_name = "" # <== Add model weight location here + + # Set to Meta Llama 2 by default. + self.model_name = "meta-llama/Llama-2-7b-hf" + self.dataset_name = "timdettmers/openassistant-guanaco" self.dataset_text_field = "text" self.learning_rate = 1.41e-5 @@ -35,6 +38,10 @@ def __init__(self): self.num_warmup_steps = 5 self.num_training_steps = 10 + # This is either provided by the user or it will be set when the + # model weights are downloaded. + self.weights_cache_dir = "" + hyperparams = HyperParameters() @@ -76,13 +83,49 @@ def tokenize(element): return train_dataloader +def ensure_model_is_downloaded(hyperparams): + assert hyperparams.model_name in [ + "meta-llama/Meta-Llama-3-8B", + "meta-llama/Llama-2-7b-hf", + ], "Only Meta Llama 2 7B and Meta Llama 3 8B models are supported!" + + # Login using Huggingface Hub API + from huggingface_hub import login + + try: + login(hyperparams.hf_access_token) + except Exception as e: + if "Invalid token passed!" in str(e): + print( + "Please pass a valid HF Access Token! More info at" + " https://huggingface.co/docs/hub/en/security-tokens." + ) + else: + print(f"Exception is {e}") + + # Download the model if it doesn't exist + from huggingface_hub import snapshot_download + + supplied_cache_dir = ( + hyperparams.weights_cache_dir if hyperparams.weights_cache_dir != "" else None + ) + hyperparams.weights_cache_dir = snapshot_download( + repo_id=hyperparams.model_name, cache_dir=supplied_cache_dir + ) + + print(f"Model cache directory : {hyperparams.weights_cache_dir}") + + def init_baseline_model(hyperparams): + # Download and cache the weights + ensure_model_is_downloaded(hyperparams) + # Init the model - config = AutoConfig.from_pretrained(hyperparams.model_name) + config = AutoConfig.from_pretrained(hyperparams.weights_cache_dir) # make sure to use flash_attention to do iso comparison with TELlamaModel config._attn_implementation = "flash_attention_2" model = AutoModelForCausalLM.from_pretrained( - hyperparams.model_name, + hyperparams.weights_cache_dir, config=config, torch_dtype=torch.bfloat16, ) @@ -94,13 +137,16 @@ def init_baseline_model(hyperparams): def init_te_llama_model(hyperparams): + # Download and cache the weights + ensure_model_is_downloaded(hyperparams) + # Init the model from te_llama import TELlamaForCausalLM - config = AutoConfig.from_pretrained(hyperparams.model_name) + config = AutoConfig.from_pretrained(hyperparams.weights_cache_dir) config._attn_implementation = "flash_attention_2" model = TELlamaForCausalLM.from_pretrained_local( - hyperparams.model_name, + hyperparams.weights_cache_dir, config=config, torch_dtype=torch.bfloat16, ) From 6a2109fd6f8922b7fb6e58d05577e024fe9adf97 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Fri, 20 Sep 2024 16:05:05 -0700 Subject: [PATCH 42/50] Restore compatibility with Python 3.8 (#1189) * Restore compatibility with Python 3.8 Signed-off-by: Przemyslaw Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Przemyslaw Tredak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/pytorch/distributed.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index fdf65db21e..e9fb11e3b9 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -354,12 +354,8 @@ def backward( # Compute the forward pass. detached_inputs = detach_variable(inputs) - with ( - torch.enable_grad(), - ctx.recompute_ctx, - ctx.torch_gpu_amp_ctx, - ctx.torch_cpu_amp_ctx, - activation_recompute_forward(activation_recompute=True, recompute_phase=True), + with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward( + activation_recompute=True, recompute_phase=True ): outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) @@ -680,13 +676,9 @@ def checkpoint( torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts() def recompute_fn(*args, **kwargs): - with ( - torch.autograd.enable_grad(), - te_recompute_ctx, - user_recompute_ctx, - torch_gpu_amp_forward_ctx, - torch_cpu_amp_forward_ctx, - ): + with torch.autograd.enable_grad(), ( + te_recompute_ctx + ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx: function(*args, **kwargs) # Initialize a new checkpoint frame for each new forward pass. From d673e49fd141df5d64805f3252ca28c9f1ca5f7e Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 27 Sep 2024 11:33:51 -0700 Subject: [PATCH 43/50] [PyTorch] Fix detection of 3 in 3hd/h3d layouts (#1187) * fix detection of 3 in 3hd/h3d layouts Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * error out when invalid layout group is provided Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../pytorch/csrc/extensions/attention.cu | 36 +++++++++++++++---- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index fb1fc97a33..b2968a688d 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -95,9 +95,21 @@ std::vector fused_attn_fwd_qkvpacked( auto qkv_sizes = QKV.sizes().vec(); std::vector qkv_shape{qkv_sizes.begin(), qkv_sizes.end()}; std::vector q_shape; - for (auto i : qkv_shape) { - if (i != 3) { - q_shape.push_back(i); + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + int loc_3 = 0; + switch (layout_group) { + case NVTE_3HD: + loc_3 = qkv_sizes.size() - 3; + break; + case NVTE_H3D: + loc_3 = qkv_sizes.size() - 2; + break; + default: + NVTE_ERROR("Invalid QKV layout group."); + } + for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) { + if (it - qkv_shape.begin() != loc_3) { + q_shape.push_back(*it); } } std::vector o_shape{q_shape.begin(), q_shape.end()}; @@ -252,9 +264,21 @@ std::vector fused_attn_bwd_qkvpacked( auto qkv_sizes = QKV.sizes().vec(); std::vector qkv_shape{qkv_sizes.begin(), qkv_sizes.end()}; std::vector q_shape; - for (auto i : qkv_shape) { - if (i != 3) { - q_shape.push_back(i); + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + int loc_3 = 0; + switch (layout_group) { + case NVTE_3HD: + loc_3 = qkv_sizes.size() - 3; + break; + case NVTE_H3D: + loc_3 = qkv_sizes.size() - 2; + break; + default: + NVTE_ERROR("Invalid QKV layout group."); + } + for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) { + if (it - qkv_shape.begin() != loc_3) { + q_shape.push_back(*it); } } auto h = q_shape[q_shape.size() - 2]; From 9b9a75ac2cd1dcfaa0c82a4fa5024acf5ee6f6de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Fri, 27 Sep 2024 19:48:43 +0200 Subject: [PATCH 44/50] [PyTorch] Add GroupedLinear to the docs and fix typos (#1206) * Docs fixes Signed-off-by: Pawel Gadzinski * docs fix Signed-off-by: Pawel Gadzinski * docs fix Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Pawel Gadzinski Co-authored-by: Pawel Gadzinski --- docs/api/pytorch.rst | 3 +++ transformer_engine/pytorch/attention.py | 2 +- transformer_engine/pytorch/module/grouped_linear.py | 6 +++--- transformer_engine/pytorch/module/layernorm.py | 2 +- transformer_engine/pytorch/module/layernorm_linear.py | 4 ++-- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- transformer_engine/pytorch/module/linear.py | 6 +++--- transformer_engine/pytorch/module/rmsnorm.py | 2 +- transformer_engine/pytorch/transformer.py | 2 +- 9 files changed, 16 insertions(+), 13 deletions(-) diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index a210019dc1..b097f14475 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -9,6 +9,9 @@ pyTorch .. autoapiclass:: transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs) :members: forward, set_tensor_parallel_group +.. autoapiclass:: transformer_engine.pytorch.GroupedLinear(in_features, out_features, bias=True, **kwargs) + :members: forward, set_tensor_parallel_group + .. autoapiclass:: transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs) .. autoapiclass:: transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index f8ba46b2ea..bacadf2cd5 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -7856,7 +7856,7 @@ class MultiheadAttention(torch.nn.Module): bias : bool, default = `True` if set to `False`, the transformer layer will not learn any additive biases. device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. qkv_format: str, default = `sbhd` diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 10c8d91551..0bad1306c3 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -528,11 +528,11 @@ class GroupedLinear(TransformerEngineBaseModule): used for initializing weights in the following way: `init_method(weight)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. get_rng_state_tracker : Callable, default = `None` - used to get the random number generator state tracker for initilizeing weights. + used to get the random number generator state tracker for initializing weights. rng_tracker_name : str, default = `None` the param passed to get_rng_state_tracker to get the specific rng tracker. device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. @@ -548,7 +548,7 @@ class GroupedLinear(TransformerEngineBaseModule): `set_tensor_parallel_group(tp_group)` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. - parallel_mode : {None, 'Column', 'Row'}, default = `None` + parallel_mode : {None, 'column', 'row'}, default = `None` used to decide whether this GroupedLinear layer is Column Parallel Linear or Row Parallel Linear as described `here `_. When set to `None`, no communication is performed. diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index ec33ad2033..292fcd06de 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -110,7 +110,7 @@ class LayerNorm(torch.nn.Module): y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. """ diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index da77879e06..92030a7f7a 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -816,7 +816,7 @@ class LayerNormLinear(TransformerEngineBaseModule): y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. @@ -832,7 +832,7 @@ class LayerNormLinear(TransformerEngineBaseModule): `set_tensor_parallel_group(tp_group)` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. - parallel_mode : {None, 'Column', 'Row'}, default = `None` + parallel_mode : {None, 'column', 'row'}, default = `None` used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described `here `_. When set to `None`, no communication is performed. diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index b802c972d4..6d5609ccd2 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1193,7 +1193,7 @@ class LayerNormMLP(TransformerEngineBaseModule): y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index a7be82ccf1..8e19a65a28 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -650,7 +650,7 @@ class Linear(TransformerEngineBaseModule): used for initializing weights in the following way: `init_method(weight)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. get_rng_state_tracker : Callable, default = `None` - used to get the random number generator state tracker for initilizeing weights. + used to get the random number generator state tracker for initializing weights. rng_tracker_name : str, default = `None` the param passed to get_rng_state_tracker to get the specific rng tracker. parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None @@ -662,7 +662,7 @@ class Linear(TransformerEngineBaseModule): names that end in `_weight` or `_bias`, so trailing underscores are stripped from any provided names. device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. @@ -678,7 +678,7 @@ class Linear(TransformerEngineBaseModule): `set_tensor_parallel_group(tp_group)` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. - parallel_mode : {None, 'Column', 'Row'}, default = `None` + parallel_mode : {None, 'column', 'row'}, default = `None` used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described `here `_. When set to `None`, no communication is performed. diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index 969a468426..d5dc400206 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -120,7 +120,7 @@ class RMSNorm(torch.nn.Module): .. math:: y = \frac{x}{RMS_\varepsilon(x)} * (1 + \gamma) device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. """ diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 958c7019ba..020d262be2 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -173,7 +173,7 @@ class TransformerLayer(torch.nn.Module): Type of activation used in MLP block. Options are: 'gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu' and 'srelu'. device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. attn_input_format: {'sbhd', 'bshd'}, default = 'sbhd' From 404189766cfad8b6d6f0a393770fa12c3301693b Mon Sep 17 00:00:00 2001 From: Ryan Date: Tue, 17 Sep 2024 11:59:25 -0700 Subject: [PATCH 45/50] Allow specifying cmake setup directory (#1186) Allow specifying cmake directory Signed-off-by: Ryan Li Co-authored-by: Ryan Li --- build_tools/build_ext.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index f71cef08ea..af11ada34c 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -106,8 +106,12 @@ def run(self) -> None: if isinstance(ext, CMakeExtension): print(f"Building CMake extension {ext.name}") # Set up incremental builds for CMake extensions - setup_dir = Path(__file__).resolve().parent.parent - build_dir = setup_dir / "build" / "cmake" + build_dir = os.getenv("NVTE_CMAKE_BUILD_DIR") + if build_dir: + build_dir = Path(build_dir).resolve() + else: + root_dir = Path(__file__).resolve().parent.parent + build_dir = root_dir / "build" / "cmake" # Ensure the directory exists build_dir.mkdir(parents=True, exist_ok=True) From 536ac363590432ee660785e9fa9b1b02e57bf153 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 17 Sep 2024 12:34:05 -0700 Subject: [PATCH 46/50] Add docs for installing from PyPI (#1184) * Add PyPI install instructions Signed-off-by: Kirthi Shankar Sivamani * Review from @timmoon10 Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- README.rst | 10 +++++++++- docs/installation.rst | 11 +++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 25ed8af1de..6cc7eeae8a 100644 --- a/README.rst +++ b/README.rst @@ -174,7 +174,15 @@ To install the latest stable version of Transformer Engine, pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable -This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch). +This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch,paddle). + +Alternatively, the package can be directly installed from `Transformer Engine's PyPI `_, e.g. + +.. code-block:: bash + + pip install transformer_engine[pytorch] + +To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch,paddle]). Transformer Engine ships wheels for the core library as well as the PaddlePaddle extensions. Source distributions are shipped for the JAX and PyTorch extensions. From source ^^^^^^^^^^^ diff --git a/docs/installation.rst b/docs/installation.rst index 012f3303cb..9ac0ddf841 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -28,6 +28,17 @@ Transformer Engine library is preinstalled in the PyTorch container in versions on `NVIDIA GPU Cloud `_. +pip - from PyPI +----------------------- + +Transformer Engine can be directly installed from `our PyPI `_, e.g. + +.. code-block:: bash + + pip install transformer_engine[pytorch] + +To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch,paddle]). Transformer Engine ships wheels for the core library as well as the PaddlePaddle extensions. Source distributions are shipped for the JAX and PyTorch extensions. + pip - from GitHub ----------------------- From 1c209c9a1862a0ba09c36bfcc18139b5e48f34e8 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 18 Sep 2024 13:09:20 -0500 Subject: [PATCH 47/50] [PyTorch] Check network interface name when initializing Userbuffers (#1175) * Check if network interface name is valid and show useful warning message when initializing Userbuffers Signed-off-by: Alp Dener * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix formatting issue in warning message. Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Alp Dener * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Alp Dener Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/module/base.py | 30 +++++++++++++++++------ 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 3375b8ab7d..644af2c22c 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -138,15 +138,29 @@ def initialize_ub( ) if ifname is not None: - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - try: - hostname = socket.inet_ntoa( - fcntl.ioctl( - s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) - )[20:24] + # Make sure the ifname found in the environment is a valid network interface + if ifname in [name for _, name in socket.if_nameindex()]: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + hostname = socket.inet_ntoa( + fcntl.ioctl( + s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) + )[20:24] + ) + except OSError as err: + raise OSError(f"Invalid network interface: {ifname}") from err + finally: + s.close() + else: + ifname_warning = ( + f"'{ifname}' is not a valid network interface! `te.initialize_ub()` will" + " attempt to " + + "detect ranks on the same node by matching 'socket.gethostname()', which is " + + "known to fail on virtual clusters like Kubernetes. If Userbuffers " + + "initialization fails, please set the 'NVTE_UB_SOCKET_IFNAME' variable in " + + "your environment to the correct network interface." ) - except OSError as err: - raise OSError(f"Invalid network interface: {ifname}") from err + warnings.warn(ifname_warning, UserWarning) hostnames = [None for _ in range(world_size)] torch.distributed.all_gather_object(hostnames, hostname, world_group) From c3280716deb2fbba26e8a19951d4bde7f6c1660c Mon Sep 17 00:00:00 2001 From: Sangkug Lym Date: Tue, 24 Sep 2024 21:24:28 -0700 Subject: [PATCH 48/50] fix NVTE_UB_WITH_MPI read (#1194) * fix NVTE_UB_WITH_MPI read Signed-off-by: Sangkug Lym * Add default value Signed-off-by: Sangkug Lym --------- Signed-off-by: Sangkug Lym Co-authored-by: Kirthi Shankar Sivamani --- build_tools/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 3725e58c87..4563a0272a 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -88,7 +88,7 @@ def setup_pytorch_extension( # Libraries library_dirs = [] libraries = [] - if os.getenv("NVTE_UB_WITH_MPI"): + if bool(int(os.getenv("NVTE_UB_WITH_MPI", 0))): assert ( os.getenv("MPI_HOME") is not None ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" From 458c7de038ed34bdaf471ced4e3162a28055def7 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> Date: Fri, 27 Sep 2024 11:56:03 -0700 Subject: [PATCH 49/50] Fix CP unit test on A100 and L40s (#1211) skip FP8 CP tests if hardware does not support FP8 Signed-off-by: Xiaowei Ren --- tests/pytorch/fused_attn/test_fused_attn_with_cp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index d6358d1062..c1c18ffe47 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -112,6 +112,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): pytest.skip("THD format is only supported on sm90+!") if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0): pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!") + if dtype == "fp8" and get_device_compute_capability() < (9, 0): + pytest.skip("FP8 attention is only supported on sm90+!") config = model_configs_fused_attn[model] if qkv_format == "thd" and config.num_heads != config.num_gqa_groups: From 4df84889cb5743113b0fa59839b941486df16ace Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Mon, 30 Sep 2024 19:33:23 -0700 Subject: [PATCH 50/50] Removed the unused options from GroupedLinear docs and fixed the bug with offsets (#1220) * Removing the unused options from GroupedLinear docs and fixing the bug with offsets Signed-off-by: Przemyslaw Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * offsets -> fp8_meta_offsets Signed-off-by: Przemyslaw Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Przemyslaw Tredak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../pytorch/module/grouped_linear.py | 91 ++++++------------- 1 file changed, 27 insertions(+), 64 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 0bad1306c3..14edd64249 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -44,18 +44,6 @@ __all__ = ["GroupedLinear"] -""" -The offset for fp8_meta_index. -_GEMM_INPUT = 0 -_GEMM_WEIGHT = num_gemms -_GEMM_OUTPUT = 2 * num_gemms -Must be properly set in GroupedLinear's initialization. -""" -_GEMM_INPUT = 0 -_GEMM_WEIGHT = 0 -_GEMM_OUTPUT = 0 -_GRAD_OUTPUT = 0 - class _GroupedLinear(torch.autograd.Function): """GroupedLinear semi-top level module @@ -74,12 +62,9 @@ def forward( fp8_meta: Dict[str, Any], fuse_wgrad_accumulation: bool, cpu_offloading: bool, - tp_group: Union[dist_group_type, None], - tp_size: int, sequence_parallel: bool, - tensor_parallel: bool, activation_dtype: torch.dtype, - parallel_mode: Union[str, None], + fp8_meta_offsets: Dict[str, int], is_grad_enabled: bool, weights_fp8: List[Union[Float8Tensor, None]], *weights_and_biases: Union[Float8Tensor, torch.Tensor, None], @@ -103,7 +88,6 @@ def forward( inputmats_t = [] inputmat_scale_inv = None - global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT if fp8: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) inputmat_scale_inv = torch.empty([num_gemms], dtype=torch.float32, device=inp.device) @@ -114,7 +98,9 @@ def forward( and not sequence_parallel ): # FP8 input for forward, FP8 input transpose for backward wgrad - indices = list(range(_GEMM_INPUT, _GEMM_INPUT + num_gemms)) + indices = list( + range(fp8_meta_offsets["input"], fp8_meta_offsets["input"] + num_gemms) + ) inputmats, inputmats_t = fp8_multi_cast_transpose_fused( inputmats_no_fp8, fp8_meta["scaling_fwd"], @@ -130,7 +116,7 @@ def forward( cast_to_fp8( inputmats_no_fp8[i], fp8_meta["scaling_fwd"], - _GEMM_INPUT + i, + fp8_meta_offsets["input"] + i, fp8_dtype_forward, scale_inv=inputmat_scale_inv, ) @@ -194,14 +180,14 @@ def forward( for i in range(num_gemms): # amax of input amin, amax = inputmats[i].aminmax() - fp8_meta["scaling_fwd"].amax_history[0][_GEMM_INPUT + i] = torch.max( - -amin, amax - ).float() + fp8_meta["scaling_fwd"].amax_history[0][fp8_meta_offsets["input"] + i] = ( + torch.max(-amin, amax).float() + ) # amax of weight amin, amax = weights[i].aminmax() - fp8_meta["scaling_fwd"].amax_history[0][_GEMM_WEIGHT + i] = torch.max( - -amin, amax - ).float() + fp8_meta["scaling_fwd"].amax_history[0][fp8_meta_offsets["weight"] + i] = ( + torch.max(-amin, amax).float() + ) out = torch.empty( [sum(m_splits), weights[0].size(0)], @@ -266,11 +252,8 @@ def forward( ctx.is_first_microbatch = is_first_microbatch ctx.use_bias = use_bias ctx.sequence_parallel = sequence_parallel - ctx.tensor_parallel = tensor_parallel ctx.inp_shape = inp.shape - ctx.parallel_mode = parallel_mode - ctx.tp_group = tp_group - ctx.tp_size = tp_size + ctx.fp8_meta_offsets = fp8_meta_offsets ctx.requires_dgrad = inp.requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False if ctx.fp8 and requires_grad(inp, weights[0], biases[0]): @@ -300,7 +283,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], w.main_grad = main_grads[i] weights[i] = w - global _GEMM_INPUT, _GEMM_WEIGHT, _GRAD_OUTPUT # preprocess grad_output grad_output = grad_output.contiguous() grad_output_mats = torch.split( @@ -318,13 +300,18 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], fp8_cast_transpose_bgrad_fused( grad_output_mats[i], ctx.fp8_meta["scaling_bwd"], - _GRAD_OUTPUT + i, + ctx.fp8_meta_offsets["grad_output"] + i, fp8_dtype_backward, ) ) else: if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - indices = list(range(_GRAD_OUTPUT, _GRAD_OUTPUT + ctx.num_gemms)) + indices = list( + range( + ctx.fp8_meta_offsets["grad_output"], + ctx.fp8_meta_offsets["grad_output"] + ctx.num_gemms, + ) + ) grad_output_c, grad_output_t = fp8_multi_cast_transpose_fused( grad_output_mats, ctx.fp8_meta["scaling_bwd"], @@ -338,7 +325,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output_c[i] = cast_to_fp8( grad_output_mats[i], ctx.fp8_meta["scaling_bwd"], - _GRAD_OUTPUT + i, + ctx.fp8_meta_offsets["grad_output"] + i, fp8_dtype_backward, ) @@ -363,7 +350,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], weights_fp8[0]._fp8_dtype, grad_output_c, ctx.fp8_meta["scaling_bwd"].scale_inv, - _GRAD_OUTPUT, + ctx.fp8_meta_offsets["grad_output"], fp8_dtype_backward, [dgrad], ctx.activation_dtype, @@ -416,7 +403,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], fp8_dtype_forward, grad_output_t, ctx.fp8_meta["scaling_bwd"].scale_inv, - _GRAD_OUTPUT, + ctx.fp8_meta_offsets["grad_output"], fp8_dtype_backward, wgrad_list, ctx.activation_dtype, @@ -497,12 +484,9 @@ def handle_custom_ddp_from_mcore(w, wgrad): None, # fp8_meta None, # fuse_wgrad_accumulation None, # cpu_offloading - None, # tp_group - None, # tp_size None, # sequence_parallel - None, # tensor_parallel None, # activation_dtype - None, # parallel_mode + None, # fp8_meta_offsets None, # is_grad_enabled None, # weights_fp8 *wgrad_list, @@ -536,23 +520,6 @@ class GroupedLinear(TransformerEngineBaseModule): responsibility to ensure all parameters are moved to the GPU before running the forward pass. - Parallelism parameters - ---------------------- - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - tp_group : ProcessGroup, default = `None` - tensor parallel process group. - tp_size : int, default = 1 - used as TP (tensor parallel) world size when TP groups are not formed during - initialization. In this case, users must call the - `set_tensor_parallel_group(tp_group)` method on the initialized module before the - forward pass to supply the tensor parallel group needed for tensor and sequence - parallel collectives. - parallel_mode : {None, 'column', 'row'}, default = `None` - used to decide whether this GroupedLinear layer is Column Parallel Linear or Row - Parallel Linear as described `here `_. - When set to `None`, no communication is performed. - Optimization parameters ----------------------- fuse_wgrad_accumulation : bool, default = 'False' @@ -613,8 +580,7 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name - global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT - _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT = 0, num_gemms, 2 * num_gemms + self._offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0} if tp_group is None: self.tp_size = tp_size @@ -651,7 +617,7 @@ def __init__( ), init_fn=init_method, get_rng_state_tracker=get_rng_state_tracker, - fp8_meta_index=_GEMM_WEIGHT + i, + fp8_meta_index=self._offsets["weight"] + i, ) # Construct bias parameters if needed @@ -774,7 +740,7 @@ def forward( weight_tensors_fp8[i] = self.get_fp8_workspace( tensor=weight_tensors[i], fp8_meta_forward=True, - fp8_meta_index=_GEMM_WEIGHT + i, + fp8_meta_index=self._offsets["weight"] + i, cache_name=(None if is_first_microbatch is None else f"weight{i}"), update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, @@ -798,12 +764,9 @@ def forward( self.fp8_meta, self.fuse_wgrad_accumulation, CPUOffloadEnabled, - self.tp_group, - self.tp_size, self.sequence_parallel, - self.tp_size > 1, self.activation_dtype, - self.parallel_mode, + self._offsets, torch.is_grad_enabled(), weight_tensors_fp8, *weight_tensors,