diff --git a/.github/workflows/deploy_nightly_docs.yml b/.github/workflows/deploy_nightly_docs.yml index 3820dbde50..cd68019c8f 100644 --- a/.github/workflows/deploy_nightly_docs.yml +++ b/.github/workflows/deploy_nightly_docs.yml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Download artifact - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4.1.7 with: name: "te_docs" path: "html" diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index a4fdbdfdfd..d2bd865a8f 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -31,7 +31,7 @@ jobs: run: | sudo apt-get update sudo apt-get install pip -y - pip install torch + pip install torch numpy export PYTHON_ONLY=1 export TE_PATH=. bash ./qa/L0_pytorch_lint/test.sh diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index cd47fa9a54..c25aa863ad 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -33,6 +33,9 @@ jobs: || github.actor == 'Oleg-Goncharov' || github.actor == 'phu0ngng' || github.actor == 'xrennvidia' + || github.actor == 'yaox12' + || github.actor == 'huanghua1994' + || github.actor == 'mgoldfarb-nvidia' ) steps: - name: Check if comment is issued by authorized person diff --git a/README.rst b/README.rst index 25ed8af1de..6cc7eeae8a 100644 --- a/README.rst +++ b/README.rst @@ -174,7 +174,15 @@ To install the latest stable version of Transformer Engine, pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable -This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch). +This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch,paddle). + +Alternatively, the package can be directly installed from `Transformer Engine's PyPI `_, e.g. + +.. code-block:: bash + + pip install transformer_engine[pytorch] + +To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch,paddle]). Transformer Engine ships wheels for the core library as well as the PaddlePaddle extensions. Source distributions are shipped for the JAX and PyTorch extensions. 
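For example, bindings for more than one framework can be requested in a single installation by listing several extras; the quotes simply keep the shell from interpreting the square brackets, and the exact set of extras depends on which frameworks you need:

.. code-block:: bash

    pip install "transformer_engine[pytorch,jax]"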
From source ^^^^^^^^^^^ diff --git a/benchmarks/attention/benchmark_attention.py b/benchmarks/attention/benchmark_attention.py index bfd7bf8471..cff7c65fbc 100644 --- a/benchmarks/attention/benchmark_attention.py +++ b/benchmarks/attention/benchmark_attention.py @@ -156,7 +156,7 @@ def parse_results(per_cudnn, per_flash, model): df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6 if per_flash > 0: - t_flash_all = df[df["Name"].str.contains("void flash")]["Duration (ns)"].to_numpy() + t_flash_all = df[df["Name"].str.contains("flash")]["Duration (ns)"].to_numpy() t_flash_all = t_flash_all.reshape(-1, per_flash) t_flash_avg = np.average(t_flash_all, axis=0) df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6 diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index 81c871de46..1cac385c6c 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -1.10.0 +1.11.0 diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index f71cef08ea..af11ada34c 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -106,8 +106,12 @@ def run(self) -> None: if isinstance(ext, CMakeExtension): print(f"Building CMake extension {ext.name}") # Set up incremental builds for CMake extensions - setup_dir = Path(__file__).resolve().parent.parent - build_dir = setup_dir / "build" / "cmake" + build_dir = os.getenv("NVTE_CMAKE_BUILD_DIR") + if build_dir: + build_dir = Path(build_dir).resolve() + else: + root_dir = Path(__file__).resolve().parent.parent + build_dir = root_dir / "build" / "cmake" # Ensure the directory exists build_dir.mkdir(parents=True, exist_ok=True) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index f932f0695e..4563a0272a 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -10,8 +10,9 @@ from .utils import ( all_files_in_dir, - cuda_version, + cuda_archs, cuda_path, + cuda_version, ) @@ -48,8 +49,6 @@ def setup_pytorch_extension( ] nvcc_flags = [ "-O3", - "-gencode", - "arch=compute_70,code=sm_70", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_BFLOAT16_OPERATORS__", @@ -61,6 +60,11 @@ def setup_pytorch_extension( "--use_fast_math", ] + cuda_architectures = cuda_archs() + + if "70" in cuda_architectures: + nvcc_flags.extend(["-gencode", "arch=compute_70,code=sm_70"]) + # Version-dependent CUDA options try: version = cuda_version() @@ -73,17 +77,18 @@ def setup_pytorch_extension( ( "--threads", os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"), - "-gencode", - "arch=compute_80,code=sm_80", - "-gencode", - "arch=compute_90,code=sm_90", ) ) + if "80" in cuda_architectures: + nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"]) + if "90" in cuda_architectures: + nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"]) + # Libraries library_dirs = [] libraries = [] - if os.getenv("NVTE_UB_WITH_MPI"): + if bool(int(os.getenv("NVTE_UB_WITH_MPI", 0))): assert ( os.getenv("MPI_HOME") is not None ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" diff --git a/build_tools/utils.py b/build_tools/utils.py index 27ceea844b..d846b87f22 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -6,12 +6,12 @@ import functools import glob +import importlib import os import re import shutil import subprocess import sys -import importlib from pathlib import Path from subprocess import CalledProcessError from typing import List, Optional, Tuple, Union @@ -188,6 +188,11 @@ def cuda_path() -> Tuple[str, str]: return cuda_home, 
nvcc_bin +@functools.lru_cache(maxsize=None) +def cuda_archs() -> str: + return os.getenv("NVTE_CUDA_ARCHS", "70;80;89;90") + + def cuda_version() -> Tuple[int, ...]: """CUDA Toolkit version as a (major, minor) tuple.""" # Query NVCC for version info diff --git a/docs/_templates/layout.html b/docs/_templates/layout.html index a68b4531e3..f94e526f57 100644 --- a/docs/_templates/layout.html +++ b/docs/_templates/layout.html @@ -1,4 +1,11 @@ {% extends "!layout.html" %} + + {% block extrahead %} + + + + {% endblock %} + {% block sidebartitle %} {{ super() }} - {%- if nvidia_analytics_id %} - - {%- endif %} + {% endblock %} + + {% block footer %} + + {% endblock %} diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index c9504c20af..b097f14475 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -9,6 +9,9 @@ pyTorch .. autoapiclass:: transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs) :members: forward, set_tensor_parallel_group +.. autoapiclass:: transformer_engine.pytorch.GroupedLinear(in_features, out_features, bias=True, **kwargs) + :members: forward, set_tensor_parallel_group + .. autoapiclass:: transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs) .. autoapiclass:: transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs) @@ -44,3 +47,7 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.make_graphed_callables .. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context + +.. autoapifunction:: transformer_engine.pytorch.moe_permute + +.. autoapifunction:: transformer_engine.pytorch.moe_unpermute diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index cb384aa10c..4413bdfd00 100644 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -102,8 +102,11 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k Custom method adapted from `from_pretrained` method in HuggingFace Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 """ - vanilla_model = cls(config).to(kwargs["torch_dtype"]) - is_local = os.path.isdir(pretrained_model_name_or_path) + # Before loading the model, set the default dtype for torch + torch.set_default_dtype(kwargs["torch_dtype"]) + + # Load the vanilla model weights + vanilla_model = cls(config) subfolder = "" variant = None if os.path.isfile( @@ -133,7 +136,7 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k else: raise AssertionError("Only sharded PyTorch ckpt format supported at the moment") - resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + resolved_archive_file, _ = get_checkpoint_shard_files( pretrained_model_name_or_path, archive_file, ) diff --git a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb index 57c1bf6601..7013e85ec6 100644 --- a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb +++ b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb @@ -247,15 +247,24 @@ "restart_jupyter_notebook()\n", "\n", "\n", - "# Import necessary packages and methods\n", + "# Import necessary packages, methods and variables\n", "from utils import *\n", "\n", "\n", - "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", - "## !!! 
`model_name` attr must point to the location of the model weights !!!\n", - "# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n", - "# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n", - "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "# Provide Huggingface Access Token\n", + "hyperparams.hf_access_token = \"\"\n", + "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "hyperparams.weights_cache_dir = \"\"\n", + "\n", + "# For Llama 2, uncomment this line (also set by default)\n", + "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n", + "\n", + "# For Llama 3, uncomment this line\n", + "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n", + "\n", "hyperparams.mixed_precision = \"bf16\"\n", "\n", "\n", @@ -554,7 +563,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "bdb34b91", "metadata": {}, "outputs": [ @@ -573,15 +582,24 @@ "restart_jupyter_notebook()\n", "\n", "\n", - "# Import necessary packages and methods\n", + "# Import necessary packages, methods and variables\n", "from utils import *\n", "\n", "\n", - "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", - "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n", - "# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n", - "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "# Provide Huggingface Access Token\n", + "hyperparams.hf_access_token = \"\"\n", + "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "hyperparams.weights_cache_dir = \"\"\n", + "\n", + "# For Llama 2, uncomment this line (also set by default)\n", + "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n", + "\n", + "# For Llama 3, uncomment this line\n", + "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n", + "\n", "hyperparams.mixed_precision = \"bf16\"\n", "\n", "\n", @@ -653,15 +671,24 @@ "restart_jupyter_notebook()\n", "\n", "\n", - "# Import necessary packages and methods\n", + "# Import necessary packages, methods and variables\n", "from utils import *\n", "\n", "\n", - "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", - "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n", - "# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n", - "hyperparams.model_name = \"\" # <== Add model weight location here e.g. 
\"/path/to/downloaded/llama/weights\"\n", + "# Provide Huggingface Access Token\n", + "hyperparams.hf_access_token = \"\"\n", + "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "hyperparams.weights_cache_dir = \"\"\n", + "\n", + "# For Llama 2, uncomment this line (also set by default)\n", + "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n", + "\n", + "# For Llama 3, uncomment this line\n", + "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n", + "\n", "hyperparams.mixed_precision = \"fp8\"\n", "\n", "\n", diff --git a/docs/examples/te_llama/utils.py b/docs/examples/te_llama/utils.py index b6b3683d4c..1aebe13afb 100644 --- a/docs/examples/te_llama/utils.py +++ b/docs/examples/te_llama/utils.py @@ -25,7 +25,10 @@ class HyperParameters: def __init__(self): self.mixed_precision = "bf16" - # self.model_name = "" # <== Add model weight location here + + # Set to Meta Llama 2 by default. + self.model_name = "meta-llama/Llama-2-7b-hf" + self.dataset_name = "timdettmers/openassistant-guanaco" self.dataset_text_field = "text" self.learning_rate = 1.41e-5 @@ -35,6 +38,10 @@ def __init__(self): self.num_warmup_steps = 5 self.num_training_steps = 10 + # This is either provided by the user or it will be set when the + # model weights are downloaded. + self.weights_cache_dir = "" + hyperparams = HyperParameters() @@ -76,13 +83,49 @@ def tokenize(element): return train_dataloader +def ensure_model_is_downloaded(hyperparams): + assert hyperparams.model_name in [ + "meta-llama/Meta-Llama-3-8B", + "meta-llama/Llama-2-7b-hf", + ], "Only Meta Llama 2 7B and Meta Llama 3 8B models are supported!" + + # Login using Huggingface Hub API + from huggingface_hub import login + + try: + login(hyperparams.hf_access_token) + except Exception as e: + if "Invalid token passed!" in str(e): + print( + "Please pass a valid HF Access Token! More info at" + " https://huggingface.co/docs/hub/en/security-tokens." 
+ ) + else: + print(f"Exception is {e}") + + # Download the model if it doesn't exist + from huggingface_hub import snapshot_download + + supplied_cache_dir = ( + hyperparams.weights_cache_dir if hyperparams.weights_cache_dir != "" else None + ) + hyperparams.weights_cache_dir = snapshot_download( + repo_id=hyperparams.model_name, cache_dir=supplied_cache_dir + ) + + print(f"Model cache directory : {hyperparams.weights_cache_dir}") + + def init_baseline_model(hyperparams): + # Download and cache the weights + ensure_model_is_downloaded(hyperparams) + # Init the model - config = AutoConfig.from_pretrained(hyperparams.model_name) + config = AutoConfig.from_pretrained(hyperparams.weights_cache_dir) # make sure to use flash_attention to do iso comparison with TELlamaModel config._attn_implementation = "flash_attention_2" model = AutoModelForCausalLM.from_pretrained( - hyperparams.model_name, + hyperparams.weights_cache_dir, config=config, torch_dtype=torch.bfloat16, ) @@ -94,13 +137,16 @@ def init_baseline_model(hyperparams): def init_te_llama_model(hyperparams): + # Download and cache the weights + ensure_model_is_downloaded(hyperparams) + # Init the model from te_llama import TELlamaForCausalLM - config = AutoConfig.from_pretrained(hyperparams.model_name) + config = AutoConfig.from_pretrained(hyperparams.weights_cache_dir) config._attn_implementation = "flash_attention_2" model = TELlamaForCausalLM.from_pretrained_local( - hyperparams.model_name, + hyperparams.weights_cache_dir, config=config, torch_dtype=torch.bfloat16, ) diff --git a/docs/installation.rst b/docs/installation.rst index 012f3303cb..9ac0ddf841 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -28,6 +28,17 @@ Transformer Engine library is preinstalled in the PyTorch container in versions on `NVIDIA GPU Cloud `_. +pip - from PyPI +----------------------- + +Transformer Engine can be directly installed from `our PyPI `_, e.g. + +.. code-block:: bash + + pip install transformer_engine[pytorch] + +To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch,paddle]). Transformer Engine ships wheels for the core library as well as the PaddlePaddle extensions. Source distributions are shipped for the JAX and PyTorch extensions. + pip - from GitHub ----------------------- diff --git a/examples/jax/encoder/README.md b/examples/jax/encoder/README.md index f8e150eac6..575f7be6e0 100644 --- a/examples/jax/encoder/README.md +++ b/examples/jax/encoder/README.md @@ -1,6 +1,6 @@ # Basic Transformer Encoder Example with Optional FP8 # -This example uses Transformer Encoder to demonstrate the Transformer Engine usage. And more focus on scaling up training on multiple GPUs. Highly recommend studying the [MNIST example of the Transformer Engine](/examples/jax/mnist) before reading this example. The Transformer Engine is built on top of [Flax](https://github.com/google/flax). Thus, examples use `pjit` to set up multiple GPU training. The basic pjit usage can be referred to [Scale up Flax Modules on multiple devices with pjit](https://flax.readthedocs.io/en/latest/guides/flax_on_pjit.html). +This example uses Transformer Encoder to demonstrate the Transformer Engine usage. And more focus on scaling up training on multiple GPUs. Highly recommend studying the [MNIST example of the Transformer Engine](/examples/jax/mnist) before reading this example. 
The Transformer Engine is built on top of [Flax](https://github.com/google/flax). Thus, examples use `jit` with `in_shardings` and `out_shardings` parameters to set up multiple GPU training. For basic parallel jit usage, refer to [Scale up Flax Modules on multiple devices](https://flax.readthedocs.io/en/latest/guides/flax_on_pjit.html). ## Single GPU ## @@ -31,11 +31,11 @@ python test_single_gpu_encoder.py --use-fp8 4. The next is to create sharding rules, mapping the device axis to the logical axis. The `te.extend_logical_axis_rules` under fp8_autocast will return a list of pairs of the mapping, such as `(('batch', 'data'), ...)`. The first is the logical axis and second is the device axis. -5. Refer structure of `abs_var_collect['params']` and `abs_var_collect['params_axes']` to set up `PartitionSpec` for pjit. All logical axes should be replaced by device axes. If the value of PartitionSpec is None, that means no sharding, broadcasting the data to every device. Note that the `params_axes` attribute is provided by Transformer Engine. The Flax's module doesn't have it, such as `nn.Embed`. For nn.Embed, assigning an empty PartitionSpec is fine because each device has its own embedding layer in DP mode. The `get_params_pspec` routine is used for this purpose. Because each device has a complete model in DP mode, all values of PartitionSpec in params_pspec should be None. This will be different in the model parallelism example. +5. Refer to the structure of `abs_var_collect['params']` and `abs_var_collect['params_axes']` to set up `PartitionSpec` for parallel jit. All logical axes should be replaced by device axes. If the value of PartitionSpec is None, that means no sharding, broadcasting the data to every device. Note that the `params_axes` attribute is provided by Transformer Engine. The Flax's module doesn't have it, such as `nn.Embed`. For nn.Embed, assigning an empty PartitionSpec is fine because each device has its own embedding layer in DP mode. The `get_params_sharding` routine is used for this purpose. Because each device has a complete model in DP mode, all values of PartitionSpec in params_sharding should be None. This will be different in the model parallelism example. -6. Fill in `params_pspec` and `encoder.init` to pjit to get a compiled function, `pjit_encoder_init `, and use it to initialize the model, so JAX now can know how to do the sharding. +6. Pass `params_sharding` and `encoder.init` to jit to get a compiled function, `jit_encoder_init`, and use it to initialize the model, so JAX now knows how to do the sharding. -7. The `train_step` and `eval_step` also need to be compiled by pjit. Thus, every input and output argument has to be set up `PartitionSpec` if the argument contains a tensor. For instance, the `input_pspec` is `PartitionSpec('data', None)` because the input shape is (batch size, sequence length). Then, the rest of the workflow is similar to the previous example. +7. The `train_step` and `eval_step` also need to be compiled by jit. Thus, every input and output argument that contains a tensor has to have a `PartitionSpec` set up. For instance, the `input_pspec` is `PartitionSpec('data', None)` because the input shape is (batch size, sequence length). Then, the rest of the workflow is similar to the previous example. 8. Use `CUDA_VISIBLE_DEVICES` to control the number of GPUs used. For example, if the system has 8 GPUs but only 4 GPUs need to be used, then: ```sh @@ -84,7 +84,7 @@ python test_model_parallel_encoder.py --use-fp8 1. 
This example inherits previous model parallelism example, but uses multiprocessing instead of single-program multiple-data (SPMD). It uses 1 GPU per process. -2. The benefit of multiprocessing is to setup hardware affinity for GPUs, such as NUMA binding. It may help improve performance and stability. Please refer to [Best Practices When Benchmarking CUDA Applications](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2019-s9956/) for more details. +2. There is two main benefits of multiprocessing: support multi-node and to setup hardware affinity for GPUs, such as NUMA binding. Affinity may help improve performance and stability. Please refer to [Best Practices When Benchmarking CUDA Applications](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2019-s9956/) for more details. 3. The quick way to check system topology is to use `nvidia-smi`, for example: ```sh diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 646d6e0a12..25d744887e 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -17,7 +17,7 @@ from flax.linen import partitioning as nn_partitioning from flax.training import train_state from jax.experimental import mesh_utils -from jax.experimental.pjit import pjit +from jax.sharding import PartitionSpec, NamedSharding import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax @@ -223,32 +223,36 @@ def check_fp8(state, var_collect, inputs, masks, labels): ) -def get_params_pspec(sharding_rules, abs_var_collect): - """Refer params to create params partition spec""" - rules_dict = {} - for key, value in sharding_rules: - rules_dict[key] = value +def get_params_sharding(sharding_rules, abs_var_collect, mesh): + """Refer params to create params sharding""" + rules_dict = dict(sharding_rules) def to_device_axis(logical_axis): partitions = [rules_dict[key] for key in logical_axis] - return jax.sharding.PartitionSpec(*partitions) + return NamedSharding(mesh, PartitionSpec(*partitions)) params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {}) - params_axes_pspec = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes)) - params_axes_pspec = flax.core.unfreeze(params_axes_pspec) - params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY]) - params_pspec = {**params_pspec, **params_axes_pspec} - return params_pspec + params_axes_sharding = jax.tree_util.tree_map( + to_device_axis, nn_partitioning.get_axis_names(params_axes) + ) + params_axes_sharding = flax.core.unfreeze(params_axes_sharding) + params_sharding = jax.tree_util.tree_map( + lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY] + ) + params_sharding = {**params_sharding, **params_axes_sharding} + return params_sharding -def get_state_pspec(state, params_pspec): - """Refer params_pspec to create state partition spec""" +def get_state_sharding(state, params_sharding): + """Refer params_sharding to create state sharding""" def replace_params(x): - return params_pspec if isinstance(x, dict) else None + return params_sharding if isinstance(x, dict) else None - state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, dict)) - return state_pspec + state_sharding = jax.tree_util.tree_map( + replace_params, state, is_leaf=lambda x: isinstance(x, dict) + ) + return state_sharding def train_and_evaluate(args): @@ -270,7 +274,9 @@ def train_and_evaluate(args): ), 
f"Test batch size needs to be multiple of {num_gpu_dp}" device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) - with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)): + with jax.sharding.Mesh( + devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS) + ) as mesh: rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) @@ -291,34 +297,39 @@ def train_and_evaluate(args): customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules - params_pspec = get_params_pspec(sharding_rules, abs_var_collect) - inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) - masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) + params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh) + inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None)) + masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None)) - in_shardings = (None, inputs_pspec, masks_pspec) + in_shardings = (None, inputs_sharding, masks_sharding) out_shardings = { - key: params_pspec if key is PARAMS_KEY else None for key in abs_var_collect + key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect } - pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings) - var_collect = pjit_encoder_init(init_rngs, inputs, masks) + jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings) + var_collect = jit_encoder_init(init_rngs, inputs, masks) optimizer = optax.adamw(args.lr) var_collect, params = flax.core.pop(var_collect, PARAMS_KEY) state = train_state.TrainState.create( apply_fn=encoder.apply, params=params, tx=optimizer ) - state_pspec = get_state_pspec(state, params_pspec) - labels_pspec = jax.sharding.PartitionSpec( - DEVICE_DP_AXIS, + state_sharding = get_state_sharding(state, params_sharding) + labels_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS)) + + in_shardings = ( + state_sharding, + inputs_sharding, + masks_sharding, + labels_sharding, + None, + None, ) + out_shardings = (state_sharding, None, None, None) + jit_train_step = jax.jit(train_step, in_shardings, out_shardings) - in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None) - out_shardings = (state_pspec, None, None, None) - pjit_train_step = pjit(train_step, in_shardings, out_shardings) - - in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None) + in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - pjit_eval_step = pjit(eval_step, in_shardings, out_shardings) + jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -327,7 +338,7 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) rngs = {DROPOUT_KEY: dropout_rng} - pjit_train_step(state, inputs, masks, labels, var_collect, rngs) + jit_train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") return None @@ -337,11 +348,11 @@ def train_and_evaluate(args): rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} state, train_loss, train_accuracy, var_collect = train_epoch( - state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step + state, train_ds, args.batch_size, rngs, var_collect, jit_train_step ) test_loss, test_accuracy = eval_model( - 
state, test_ds, args.test_batch_size, var_collect, pjit_eval_step + state, test_ds, args.test_batch_size, var_collect, jit_eval_step ) print( diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 005ae50e72..9d08254f4d 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -17,7 +17,7 @@ from flax.linen import partitioning as nn_partitioning from flax.training import train_state from jax.experimental import mesh_utils -from jax.experimental.pjit import pjit +from jax.sharding import PartitionSpec, NamedSharding import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax @@ -202,32 +202,36 @@ def check_fp8(state, var_collect, inputs, masks, labels): ) -def get_params_pspec(sharding_rules, abs_var_collect): - """Refer params to create params partition spec""" - rules_dict = {} - for key, value in sharding_rules: - rules_dict[key] = value +def get_params_sharding(sharding_rules, abs_var_collect, mesh): + """Refer params to create params sharding""" + rules_dict = dict(sharding_rules) def to_device_axis(logical_axis): partitions = [rules_dict[key] for key in logical_axis] - return jax.sharding.PartitionSpec(*partitions) + return NamedSharding(mesh, PartitionSpec(*partitions)) params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {}) - params_axes_pspec = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes)) - params_axes_pspec = flax.core.unfreeze(params_axes_pspec) - params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY]) - params_pspec = {**params_pspec, **params_axes_pspec} - return params_pspec + params_axes_sharding = jax.tree_util.tree_map( + to_device_axis, nn_partitioning.get_axis_names(params_axes) + ) + params_axes_sharding = flax.core.unfreeze(params_axes_sharding) + params_sharding = jax.tree_util.tree_map( + lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY] + ) + params_sharding = {**params_sharding, **params_axes_sharding} + return params_sharding -def get_state_pspec(state, params_pspec): - """Refer params_pspec to create state partition spec""" +def get_state_sharding(state, params_sharding): + """Refer params_sharding to create state sharding""" def replace_params(x): - return params_pspec if isinstance(x, dict) else None + return params_sharding if isinstance(x, dict) else None - state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, dict)) - return state_pspec + state_sharding = jax.tree_util.tree_map( + replace_params, state, is_leaf=lambda x: isinstance(x, dict) + ) + return state_sharding def train_and_evaluate(args): @@ -240,7 +244,7 @@ def train_and_evaluate(args): assert args.test_batch_size % num_gpu == 0, f"Test batch size needs to be multiple of {num_gpu}" device_mesh = mesh_utils.create_device_mesh((num_gpu,)) - with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)): + with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)) as mesh: rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) @@ -260,34 +264,43 @@ def train_and_evaluate(args): abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) sharding_rules = te_flax.extend_logical_axis_rules(tuple()) - params_pspec = get_params_pspec(sharding_rules, abs_var_collect) - inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) - masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) 
+ params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh) + inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None)) + masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None)) - in_shardings = (None, inputs_pspec, masks_pspec) + in_shardings = (None, inputs_sharding, masks_sharding) out_shardings = { - key: params_pspec if key is PARAMS_KEY else None for key in abs_var_collect + key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect } - pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings) - var_collect = pjit_encoder_init(init_rngs, inputs, masks) + jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings) + var_collect = jit_encoder_init(init_rngs, inputs, masks) optimizer = optax.adamw(args.lr) var_collect, params = flax.core.pop(var_collect, PARAMS_KEY) state = train_state.TrainState.create( apply_fn=encoder.apply, params=params, tx=optimizer ) - state_pspec = get_state_pspec(state, params_pspec) - labels_pspec = jax.sharding.PartitionSpec( - DEVICE_DP_AXIS, + state_sharding = get_state_sharding(state, params_sharding) + labels_sharding = NamedSharding( + mesh, + PartitionSpec( + DEVICE_DP_AXIS, + ), ) + in_shardings = ( + state_sharding, + inputs_sharding, + masks_sharding, + labels_sharding, + None, + None, + ) + out_shardings = (state_sharding, None, None, None) + jit_train_step = jax.jit(train_step, in_shardings, out_shardings) - in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None) - out_shardings = (state_pspec, None, None, None) - pjit_train_step = pjit(train_step, in_shardings, out_shardings) - - in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None) + in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - pjit_eval_step = pjit(eval_step, in_shardings, out_shardings) + jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -296,7 +309,7 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) rngs = {DROPOUT_KEY: dropout_rng} - pjit_train_step(state, inputs, masks, labels, var_collect, rngs) + jit_train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") return None @@ -306,11 +319,11 @@ def train_and_evaluate(args): rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} state, train_loss, train_accuracy, var_collect = train_epoch( - state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step + state, train_ds, args.batch_size, rngs, var_collect, jit_train_step ) test_loss, test_accuracy = eval_model( - state, test_ds, args.test_batch_size, var_collect, pjit_eval_step + state, test_ds, args.test_batch_size, var_collect, jit_eval_step ) print( diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 286c064e96..e581dbc3f9 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -19,7 +19,7 @@ from flax.linen import partitioning as nn_partitioning from flax.training import train_state from jax.experimental import mesh_utils -from jax.experimental.pjit import pjit +from jax.sharding import PartitionSpec, NamedSharding import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax @@ -305,32 +305,36 @@ def check_fp8(state, var_collect, inputs, 
masks, labels): ) -def get_params_pspec(sharding_rules, abs_var_collect): - """Refer params to create params partition spec""" - rules_dict = {} - for key, value in sharding_rules: - rules_dict[key] = value +def get_params_sharding(sharding_rules, abs_var_collect, mesh): + """Refer params to create params sharding""" + rules_dict = dict(sharding_rules) def to_device_axis(logical_axis): partitions = [rules_dict[key] for key in logical_axis] - return jax.sharding.PartitionSpec(*partitions) + return NamedSharding(mesh, jax.sharding.PartitionSpec(*partitions)) params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {}) - params_axes_pspec = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes)) - params_axes_pspec = flax.core.unfreeze(params_axes_pspec) - params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY]) - params_pspec = {**params_pspec, **params_axes_pspec} - return params_pspec + params_axes_sharding = jax.tree_util.tree_map( + to_device_axis, nn_partitioning.get_axis_names(params_axes) + ) + params_axes_sharding = flax.core.unfreeze(params_axes_sharding) + params_sharding = jax.tree_util.tree_map( + lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY] + ) + params_sharding = {**params_sharding, **params_axes_sharding} + return params_sharding -def get_state_pspec(state, params_pspec): - """Refer params_pspec to create state partition spec""" +def get_state_sharding(state, params_sharding): + """Refer params_sharding to create state sharding""" def replace_params(x): - return params_pspec if isinstance(x, dict) else None + return params_sharding if isinstance(x, dict) else None - state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, dict)) - return state_pspec + state_sharding = jax.tree_util.tree_map( + replace_params, state, is_leaf=lambda x: isinstance(x, dict) + ) + return state_sharding def train_and_evaluate(args): @@ -362,7 +366,7 @@ def train_and_evaluate(args): device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) with jax.sharding.Mesh( devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS) - ) as shard_mesh: + ) as mesh: rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) @@ -383,34 +387,41 @@ def train_and_evaluate(args): customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules - params_pspec = get_params_pspec(sharding_rules, abs_var_collect) + params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh) inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) - in_shardings = (None, inputs_pspec, masks_pspec) + inputs_sharding = NamedSharding(mesh, inputs_pspec) + masks_sharding = NamedSharding(mesh, masks_pspec) + in_shardings = (None, inputs_sharding, masks_sharding) out_shardings = { - key: params_pspec if key is PARAMS_KEY else None for key in abs_var_collect + key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect } - pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings) - var_collect = pjit_encoder_init(init_rngs, inputs, masks) + jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings) + var_collect = jit_encoder_init(init_rngs, inputs, masks) optimizer = optax.adamw(args.lr) var_collect, params = flax.core.pop(var_collect, PARAMS_KEY) state = train_state.TrainState.create( 
apply_fn=encoder.apply, params=params, tx=optimizer ) - state_pspec = get_state_pspec(state, params_pspec) - labels_pspec = jax.sharding.PartitionSpec( - DEVICE_DP_AXIS, + state_sharding = get_state_sharding(state, params_sharding) + labels_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS)) + + in_shardings = ( + state_sharding, + inputs_sharding, + masks_sharding, + labels_sharding, + None, + None, ) + out_shardings = (state_sharding, None, None, None) + jit_train_step = jax.jit(train_step, in_shardings, out_shardings) - in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None) - out_shardings = (state_pspec, None, None, None) - pjit_train_step = pjit(train_step, in_shardings, out_shardings) - - in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None) + in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - pjit_eval_step = pjit(eval_step, in_shardings, out_shardings) + jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -419,7 +430,7 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) rngs = {DROPOUT_KEY: dropout_rng} - pjit_train_step(state, inputs, masks, labels, var_collect, rngs) + jit_train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") else: for epoch in range(1, args.epochs + 1): @@ -433,11 +444,11 @@ def train_and_evaluate(args): args.batch_size, rngs, var_collect, - pjit_train_step, - shard_mesh, + jit_train_step, + mesh, inputs_pspec, masks_pspec, - labels_pspec, + labels_sharding.spec, ) test_loss, test_accuracy = eval_model( @@ -445,11 +456,11 @@ def train_and_evaluate(args): test_ds, args.test_batch_size, var_collect, - pjit_eval_step, - shard_mesh, + jit_eval_step, + mesh, inputs_pspec, masks_pspec, - labels_pspec, + labels_sharding.spec, ) if args.process_id == 0: print( diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index e6ccf3b82f..b69aed6648 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -23,3 +23,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops_distributed.py +pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py diff --git a/setup.py b/setup.py index 942f57d3c1..0b0639aea6 100644 --- a/setup.py +++ b/setup.py @@ -13,18 +13,17 @@ from wheel.bdist_wheel import bdist_wheel from build_tools.build_ext import CMakeExtension, get_build_ext +from build_tools.te_version import te_version from build_tools.utils import ( + cuda_archs, found_cmake, found_ninja, found_pybind11, - remove_dups, get_frameworks, install_and_import, remove_dups, uninstall_te_wheel_packages, ) -from build_tools.te_version import te_version - frameworks = get_frameworks() current_file_path = Path(__file__).parent.resolve() @@ -60,10 +59,11 @@ def setup_common_extension() -> CMakeExtension: """Setup CMake extension for common library""" # Project directory root root_path = Path(__file__).resolve().parent + return CMakeExtension( name="transformer_engine", cmake_path=root_path / Path("transformer_engine/common"), - cmake_flags=[], + cmake_flags=["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())], ) @@ -93,7 +93,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: # 
Framework-specific requirements if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "pytorch" in frameworks: - install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.5.8,!=2.0.9,!=2.1.0"]) + install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"]) test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"]) if "jax" in frameworks: install_reqs.extend(["jax", "flax>=0.7.1"]) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index e590d8e92a..45806e7022 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -13,6 +13,7 @@ add_executable(test_operator test_layernorm.cu test_rmsnorm.cu test_multi_cast_transpose.cu + test_multi_padding.cu test_causal_softmax.cu ../test_common.cu) diff --git a/tests/cpp/operator/test_cast_transpose.cu b/tests/cpp/operator/test_cast_transpose.cu index 8c168c76f4..39a6614179 100644 --- a/tests/cpp/operator/test_cast_transpose.cu +++ b/tests/cpp/operator/test_cast_transpose.cu @@ -69,6 +69,8 @@ void performTest(const size_t N, const size_t H) { if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output_c.scale(); + compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); diff --git a/tests/cpp/operator/test_cast_transpose_dbias.cu b/tests/cpp/operator/test_cast_transpose_dbias.cu index a2c8594730..651508c871 100644 --- a/tests/cpp/operator/test_cast_transpose_dbias.cu +++ b/tests/cpp/operator/test_cast_transpose_dbias.cu @@ -116,6 +116,8 @@ void performTest(const size_t N, const size_t H) { if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output_c.scale(); + compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); diff --git a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu index 5920f23f38..38ac955bc9 100644 --- a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu +++ b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu @@ -132,6 +132,8 @@ void performTest(const size_t N, const size_t H) { if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output_c.scale(); + compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); diff --git a/tests/cpp/operator/test_layernorm.cu b/tests/cpp/operator/test_layernorm.cu index 07a8a8884c..cdd8e7846c 100644 --- a/tests/cpp/operator/test_layernorm.cu +++ b/tests/cpp/operator/test_layernorm.cu @@ -230,6 +230,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); if (isFp8Type(otype)) { compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / z.scale(); + compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol_stats, rtol_stats] = 
getTolerances(DType::kFloat32); diff --git a/tests/cpp/operator/test_multi_cast_transpose.cu b/tests/cpp/operator/test_multi_cast_transpose.cu index 85fc3a573a..e7fb183217 100644 --- a/tests/cpp/operator/test_multi_cast_transpose.cu +++ b/tests/cpp/operator/test_multi_cast_transpose.cu @@ -139,6 +139,10 @@ void performTest() { output_c_list[tensor_id].amax(), ref_amax_list[tensor_id], atol_amax, rtol_amax); + compareResults("scale_inv", + output_c_list[tensor_id].scale_inv(), + 1.f / output_c_list[tensor_id].scale(), + atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); compareResults("output_c", diff --git a/tests/cpp/operator/test_multi_padding.cu b/tests/cpp/operator/test_multi_padding.cu new file mode 100644 index 0000000000..e9e42725fe --- /dev/null +++ b/tests/cpp/operator/test_multi_padding.cu @@ -0,0 +1,169 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +template +void compute_ref(const std::vector>& input_list, + std::vector>& output_list, + const std::vector& height_list, + const std::vector& width_list, + const std::vector& padded_height_list) { + using compute_t = float; + for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { + const auto& input = input_list[tensor_id]; + auto& output = output_list[tensor_id]; + const size_t height = height_list[tensor_id]; + const size_t width = width_list[tensor_id]; + const size_t padded_height = padded_height_list[tensor_id]; + + for (size_t i = 0; i < padded_height; ++i) { + if (i < height) { + for (size_t j = 0; j < width; ++j) { + const compute_t x = static_cast(input[i * width + j]); + const OutputType y = static_cast(x); + output[i * width + j] = y; + } + } else { + for (size_t j = 0; j < width; ++j) { + output[i * width + j] = static_cast(0.f); + } + } + } + } +} + +template +void performTest() { + using namespace test; + + const DType itype = TypeInfo::dtype; + const DType otype = TypeInfo::dtype; + const std::vector> tensor_dims = {{1,1}, + {1,768}, + {768,1}, + {768,768}, + {43,43}, + {43,256}, + {256,43}, + {256,256}}; + const size_t num_tensors = tensor_dims.size(); + constexpr int align = 16; + + // Buffers for Transformer Engine implementation + std::vector input_list, output_list, output_t_list; + + // Buffers for reference implementation + std::vector> ref_input_list; + std::vector> ref_output_list; + std::vector ref_height_list(num_tensors), ref_width_list(num_tensors); + std::vector ref_padded_height_list(num_tensors); + + // Initialize buffers + for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { + const size_t height = tensor_dims[tensor_id].first; + const size_t width = tensor_dims[tensor_id].second; + const size_t padded_height = (height + align - 1) / align * align; + input_list.emplace_back(Tensor({ height, width }, itype)); + output_list.emplace_back(Tensor({ padded_height, width }, otype)); + + auto& input = input_list.back(); + auto& output = output_list.back(); + fillUniform(&input); + setRandomScale(&output); + + ref_input_list.emplace_back(height*width); + ref_output_list.emplace_back(padded_height*width); + + 
std::copy(input.cpu_dptr(), + input.cpu_dptr() + height * width, + ref_input_list.back().begin()); + ref_height_list[tensor_id] = height; + ref_width_list[tensor_id] = width; + ref_padded_height_list[tensor_id] = padded_height; + } + + // Transformer Engine implementation + auto make_nvte_vector = [](std::vector& tensor_list) + -> std::vector { + std::vector nvte_tensor_list; + for (auto& tensor : tensor_list) { + nvte_tensor_list.emplace_back(tensor.data()); + } + return nvte_tensor_list; + }; + nvte_multi_padding(num_tensors, + make_nvte_vector(input_list).data(), + make_nvte_vector(output_list).data(), + ref_padded_height_list.data(), + 0); + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + // Reference implementation + compute_ref(ref_input_list, + ref_output_list, + ref_height_list, + ref_width_list, + ref_padded_height_list); + + // Check correctness + for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { + auto [atol, rtol] = getTolerances(otype); + compareResults("output", + output_list[tensor_id], + ref_output_list[tensor_id].data(), + atol, rtol); + } +} + +} // namespace + +class MultiPaddingTestSuite + : public ::testing::TestWithParam< + transformer_engine::DType> {}; + +TEST_P(MultiPaddingTestSuite, TestMultiPaddingTranspose) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = GetParam(); + const DType output_type = input_type; + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTest(); + ); + ); +} + + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + MultiPaddingTestSuite, + ::testing::ValuesIn(test::all_fp_types), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(info.param); + return name; + }); diff --git a/tests/cpp/operator/test_rmsnorm.cu b/tests/cpp/operator/test_rmsnorm.cu index e4e34bac8a..0ec3a877e5 100644 --- a/tests/cpp/operator/test_rmsnorm.cu +++ b/tests/cpp/operator/test_rmsnorm.cu @@ -187,6 +187,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); if (isFp8Type(otype)) { compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / z.scale(); + compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 3a7fe33378..bbd54ecce5 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -4,6 +4,8 @@ import operator import re from functools import reduce +from itertools import product +import pytest import jax from jax.experimental.pjit import pjit, _UNSPECIFIED @@ -29,6 +31,28 @@ def generate_configs(): return configs +def generate_context_parallel_configs(): + configs = [] + + DP_sizes = (1, 2) + CP_sizes = (1, 2, 4, 8) + TP_sizes = (1, 2) + for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes): + ndev = cp * tp * dp + if is_devices_enough(ndev): + configs.append( + pytest.param( + ndev, + (dp, cp, tp), + ("dp", "cp", "tp"), + MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp"), + id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}", + ) + ) + + return configs + + COLL_AR_KEY = "all-reduce" COLL_AG_KEY = "all-gather" COLL_OTHER_KEY = "other" diff --git a/tests/jax/test_distributed_fused_attn.py 
b/tests/jax/test_distributed_fused_attn.py index 15676dd270..61d68aacae 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -3,6 +3,7 @@ # See LICENSE for license information. import pytest +from functools import partial import jax import jax.numpy as jnp @@ -10,8 +11,13 @@ from flax.linen import dot_product_attention from jax import random from jax.sharding import Mesh, NamedSharding, PartitionSpec -from distributed_test_base import generate_configs, generate_collectives_count, compare_ops -from utils import make_causal_mask, make_self_mask +from distributed_test_base import ( + generate_configs, + generate_context_parallel_configs, + generate_collectives_count, + compare_ops, +) +from utils import make_causal_mask, make_self_mask, assert_tree_like_allclose, assert_allclose from transformer_engine.jax import fp8_autocast from transformer_engine.jax.attention import ( is_fused_attn_kernel_available, @@ -19,6 +25,10 @@ AttnBiasType, AttnMaskType, QKVLayout, + QKVFormat, + get_qkv_format, + reorder_causal_load_balancing, + inverse_reorder_causal_load_balancing, ) @@ -263,7 +273,8 @@ def target_func(q, kv, mask): scaling_factor=scaling_factor, dropout_probability=dropout_prob, is_training=is_training, - ) + ), + dtype=jnp.float32, ) def ref_func(query, kv, mask): @@ -284,7 +295,7 @@ def ref_func(query, kv, mask): dtype=jnp.float32, ) - return jnp.mean(output).astype(dtype) + return jnp.mean(output, dtype=jnp.float32) (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs( data_shape, mesh_resource, attn_mask_type, dtype @@ -310,3 +321,229 @@ def ref_func(query, kv, mask): in_shardings=(q_pspec, kv_pspec, mask_pspec), out_shardings=(None, (q_pspec, kv_pspec)), ) + + +class TestDistributedContexParallelSelfAttn: + + def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype): + batch, seqlen, heads, hidden = shape + qkey, kkey, vkey = random.split(random.PRNGKey(1124), 3) + q = random.normal(qkey, shape, dtype=dtype) + k = random.normal(kkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype) + v = random.normal(vkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype) + + mask = None + if attn_mask_type == AttnMaskType.CAUSAL_MASK: + mask = make_causal_mask(batch, seqlen) + + return q, k, v, mask + + def qkv_to_layout(self, q, k, v, qkv_layout): + qkv_args = () + match qkv_layout: + case QKVLayout.BSHD_BS2HD: + k, v = map(partial(jnp.expand_dims, axis=-3), [k, v]) + kv = jnp.concatenate((k, v), axis=-3) + qkv_args = (q, kv) + case QKVLayout.BSHD_BSHD_BSHD: + qkv_args = (q, k, v) + case _: + raise ValueError(f"Unsupported {qkv_layout=}") + return qkv_args + + @pytest.mark.parametrize( + "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() + ) + @pytest.mark.parametrize( + "data_shape", + [ + pytest.param([2, 512, 12, 128], id="2-512-12-128"), + pytest.param([4, 1024, 16, 64], id="4-1024-16-64"), + ], + ) + @pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16]) + @pytest.mark.parametrize( + "attn_mask_type", + [ + pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"), + pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"), + ], + ) + @pytest.mark.parametrize("dtype", [jnp.bfloat16]) + @pytest.mark.parametrize( + "qkv_layout", + [ + pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"), + pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"), + ], + ) + @pytest.mark.parametrize( + "load_balanced", [pytest.param(False, id="UNBALANCED"), pytest.param(True, 
id="BALANCED")] + ) + def test_contex_parallel_self_attn( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + kv_groups, + attn_mask_type, + dtype, + qkv_layout, + load_balanced, + ): + attn_bias_type = AttnBiasType.NO_BIAS + dropout_prob = 0.0 + is_training = True + scaling_factor = 1.0 + dp_size, cp_size, tp_size = mesh_shape + qkv_format = get_qkv_format(qkv_layout) + + _, seqlen, num_head, hidden = data_shape + num_kv_heads = num_head // kv_groups + + # make sure the mesh evently divides cp and tp axis + if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0: + pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}") + + def target_func(q, k, v, mask): + return jnp.mean( + fused_attn( + self.qkv_to_layout(q, k, v, qkv_layout), + bias=None, + mask=mask, + seed=None, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_prob, + is_training=is_training, + context_parallel_causal_load_balanced=load_balanced, + ), + ).astype(dtype) + + def ref_func(q, k, v, mask, kv_groups): + q = jnp.squeeze(q) + k = jnp.squeeze(jnp.repeat(k, kv_groups, axis=2)) + v = jnp.squeeze(jnp.repeat(v, kv_groups, axis=2)) + output = dot_product_attention( + q, + k, + v, + bias=None, + mask=mask, + deterministic=is_training, + dropout_rate=dropout_prob, + dropout_rng=None, + dtype=jnp.float32, + ) + return jnp.mean(output).astype(dtype) + + q, k, v, mask = self.generate_inputs(data_shape, kv_groups, attn_mask_type, dtype) + + # Single GPU (reference) + ref_func_jit = jax.jit(jax.value_and_grad(ref_func, argnums=[0, 1, 2]), static_argnums=[4]) + ref_fwd, ref_grads = ref_func_jit(q, k, v, mask, kv_groups) + + # Multi GPU (function under test) + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + with mesh, fp8_autocast(mesh_resource=mesh_resource): + qkv_ps = PartitionSpec( + mesh_resource.dp_resource, + mesh_resource.cp_resource, + mesh_resource.tp_resource, + None, + ) + qkv_sharding = NamedSharding(mesh, qkv_ps) + + mask_ps = PartitionSpec( + mesh_resource.dp_resource, None, mesh_resource.cp_resource, None + ) + mask_sharding = NamedSharding(mesh, mask_ps) + + reorder = partial( + reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format + ) + inverse_reorder = partial( + inverse_reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format + ) + + if load_balanced: + q, k, v = jax.tree.map(reorder, (q, k, v)) + + q_, k_, v_ = map(partial(jax.device_put, device=qkv_sharding), [q, k, v]) + mask_ = jax.device_put(mask, device=mask_sharding) + + target_func_jit = jax.jit( + jax.value_and_grad(target_func, argnums=[0, 1, 2]), + in_shardings=[qkv_sharding, qkv_sharding, qkv_sharding, mask_sharding], + out_shardings=(None, (qkv_sharding, qkv_sharding, qkv_sharding)), + ) + + target_fwd, target_grads = target_func_jit(q_, k_, v_, mask_) + + if load_balanced: + target_dq, target_dk, target_dv = jax.tree.map(inverse_reorder, target_grads[0:3]) + target_grads = (target_dq, target_dk, target_dv, *target_grads[3:]) + + def _print_diffs(target, ref): + print("min: ", jnp.min(target), jnp.min(ref)) + print("max: ", jnp.max(target), jnp.max(ref)) + print("mean: ", jnp.mean(target), jnp.mean(ref)) + print("median: ", jnp.median(target), jnp.median(ref)) + print("std: ", jnp.std(target), jnp.std(ref)) + print("var: ", jnp.var(target), jnp.var(ref)) + print("max diff: ", 
jnp.max(jnp.abs(target - ref))) + + has_diffs = False + + try: + assert_allclose(target_fwd, ref_fwd, dtype=dtype) + except AssertionError as e: + has_diffs = True + print(f"target_fwd v. ref_fwd") + _print_diffs(target_fwd, ref_fwd) + + for i in range(len(target_grads)): + if ref_grads[i] is None or target_grads[i] is None: + # expect both none if one is + assert target_grads[i] is None and ref_grads[i] is None + else: + try: + assert_allclose(target_grads[i], ref_grads[i]) + except AssertionError as e: + has_diffs = True + print(f"target_grads[{i}] v. ref_grads[{i}]") + _print_diffs(target_grads[i], ref_grads[i]) + + assert has_diffs == False, "has_diffs != False" + + +class TestReorderCausalLoadBalancing: + @pytest.mark.parametrize("cp_size", [2, 4, 8]) + @pytest.mark.parametrize( + "shape", + [ + pytest.param([1, 16, 1, 1], id="1-16-1-1"), + pytest.param([4, 32, 12, 32], id="4-32-12-32"), + pytest.param([3, 32, 8, 64], id="3-32-8-64"), + ], + ) + @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD]) + def test(self, cp_size, shape, qkv_format): + tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16) + if qkv_format == QKVFormat.SBHD: + tensor = tensor.swapaxes(0, 1) + + ref = tensor.copy() + + reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2]) + inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2]) + + reordered = reorder(tensor, cp_size, qkv_format) + inversed = inverse(reordered, cp_size, qkv_format) + + assert jnp.array_equal(inversed, ref) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 390a3e2c4e..2709eeab30 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -29,7 +29,10 @@ get_qkv_format, ) from transformer_engine.jax.cpp_extensions import FusedAttnHelper -from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend +from transformer_engine.transformer_engine_jax import ( + NVTE_Fused_Attn_Backend, + get_cudnn_version, +) from utils import assert_allclose @@ -230,7 +233,14 @@ def customcall_fused_dpa( kwargs.pop("max_segments_per_seq") return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype) return fused_attn_thd( - qkv_args, bias, seqlens_q, seqlens_kv, offsets_q, offsets_kv, dropout_rng, **kwargs + qkv_args, + bias, + seqlens_q, + seqlens_kv, + offsets_q, + offsets_kv, + dropout_rng, + **kwargs, ).astype(query.dtype) @@ -265,6 +275,15 @@ class FusedAttnRunner: qkv_layout: QKVLayout bias_shape: BiasShape + # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue + # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases. 
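For context on the version check used by the `_get_max_segments_per_sequence` helper that follows: the `get_cudnn_version()` binding used in this file returns a single integer, and the `90400 <= ... < 90500` bounds rely on cuDNN's `major*10000 + minor*100 + patch` encoding (an assumption inferred from the check, not stated in the diff). A small sketch with a hypothetical helper name:

def decode_cudnn_version(version: int) -> tuple:
    # Assumed encoding: major*10000 + minor*100 + patch, so 90400 is 9.4.0.
    major, rest = divmod(version, 10000)
    minor, patch = divmod(rest, 100)
    return major, minor, patch

assert decode_cudnn_version(90400) == (9, 4, 0)
# Every 9.4.x release falls inside the range that triggers the workaround below.
assert all(90400 <= v < 90500 for v in (90400, 90401, 90499))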
+ def _get_max_segments_per_sequence(self): + if 90400 <= get_cudnn_version() < 90500: + return self.num_segments_per_seq + else: + # +1 for testing runtime_segments < max_segments + return self.num_segments_per_seq + 1 + def _check_configs(self): # TODO(rewang): probably adds this in is_fused_attn_available if get_qkv_format(self.qkv_layout) == QKVFormat.THD and not self.attn_mask_type in [ @@ -299,7 +318,10 @@ def _check_configs(self): self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS and self.bias_shape != BiasShape.BIAS_1HSS ): - if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: + if self.attn_mask_type not in [ + AttnMaskType.NO_MASK, + AttnMaskType.CAUSAL_MASK, + ]: pytest.skip( "B1SS, BHSS and 11SS bias shapes are only supported for " "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK." @@ -316,7 +338,12 @@ def _setup_inputs(self): q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5) q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim) - k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim) + k_shape = v_shape = ( + self.batch_size, + self.max_seqlen_kv, + self.num_heads_kv, + self.head_dim, + ) if self.attn_bias_type == AttnBiasType.NO_BIAS: bias_shape = None @@ -325,7 +352,12 @@ def _setup_inputs(self): elif self.bias_shape == BiasShape.BIAS_B1SS: bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv) elif self.bias_shape == BiasShape.BIAS_BHSS: - bias_shape = (self.batch_size, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv) + bias_shape = ( + self.batch_size, + self.num_heads_q, + self.max_seqlen_q, + self.max_seqlen_kv, + ) elif self.bias_shape == BiasShape.BIAS_11SS: bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv) else: @@ -405,7 +437,10 @@ def generate_random_segment_ids( self.segment_pad_kv = self.segment_pad_q else: self.token_kv, self.segment_pad_kv = generate_random_segment_ids( - self.batch_size, self.max_seqlen_kv, self.num_segments_per_seq, seed=2024 + self.batch_size, + self.max_seqlen_kv, + self.num_segments_per_seq, + seed=2024, ) self.pad_q = self.segment_pad_q self.pad_kv = self.segment_pad_kv @@ -464,8 +499,7 @@ def test_forward(self): "dropout_probability": self.dropout_prob, "is_training": self.is_training, "qkv_layout": self.qkv_layout, - # +1 for testing runtime_segments < max_segments - "max_segments_per_seq": self.num_segments_per_seq + 1, + "max_segments_per_seq": self._get_max_segments_per_sequence(), } # Convert the outputs to float32 for the elementwise comparison @@ -522,7 +556,7 @@ def grad_func(func, *args, **kwargs): "dropout_probability": self.dropout_prob, "is_training": self.is_training, "qkv_layout": self.qkv_layout, - "max_segments_per_seq": self.num_segments_per_seq + 1, + "max_segments_per_seq": self._get_max_segments_per_sequence(), } # We can compute dBias only for the [1, h, s, s] layout @@ -635,7 +669,16 @@ def check_dqkv(primitive, reference, pad): pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"), pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"), pytest.param(4, 512, 128, 16, 16, 64, jnp.bfloat16, id="4-512-128-16-16-64-BF16-CROSS"), - pytest.param(2, 2048, 1024, 12, 12, 64, jnp.bfloat16, id="2-2048-1048-12-12-64-BF16-CROSS"), + pytest.param( + 2, + 2048, + 1024, + 12, + 12, + 64, + jnp.bfloat16, + id="2-2048-1048-12-12-64-BF16-CROSS", + ), pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, 
id="4-128-128-16-8-64-BF16-GQA"), pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"), ], diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 82a3c8576b..d110dece53 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -22,6 +22,7 @@ get_attention_backend, _flash_attn_2_plus, _flash_attn_2_3_plus, + _flash_attn_3_plus, check_set_window_size, AttentionParams, _attention_backends, @@ -135,7 +136,6 @@ def _get_attention_backends( os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_UNFUSED_ATTN"] = "1" - global _attention_backends _attention_backends["backend_selection_requires_update"] = True alibi_slopes_shape = None @@ -233,9 +233,9 @@ def test_dot_product_attention( """Test DotProductAttention module""" # Get configs - tols = dict(atol=5e-3, rtol=5e-3) + tols = dict(atol=1e-3, rtol=1e-3) if dtype == torch.bfloat16: - tols = dict(atol=2.5e-2, rtol=2.5e-2) + tols = dict(atol=1.5e-2, rtol=1.5e-2) config = model_configs[model] is_mla = config.head_dim_qk != config.head_dim_v if qkv_layout is None: @@ -420,6 +420,10 @@ def test_dpa_mla(dtype, model_configs, model): "mask_8_1": ModelConfig( 1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" ), + "mask_9_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"), + "mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"), + "mask_10_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), + "mask_10_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), } @@ -674,7 +678,6 @@ def _run_dot_product_attention( if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0" - global _attention_backends _attention_backends["backend_selection_requires_update"] = True # Create seqlens @@ -1032,7 +1035,7 @@ def test_transformer_layer( # Get configs config = model_configs[model] - tols = dict(atol=5e-1, rtol=5e-2) + tols = dict(atol=5e-2, rtol=5e-2) workspace_opt = True # Test backend availability @@ -1163,7 +1166,6 @@ def _run_transformer_layer( os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" - global _attention_backends _attention_backends["backend_selection_requires_update"] = True # Create input tensor @@ -1301,6 +1303,7 @@ def _run_transformer_layer( "fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"), "fp8_13": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"), "fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), + "fp8_15": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), } param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16] @@ -1312,6 +1315,27 @@ def _rmse(a, b): return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum()) +def _error(a, b, name_a, name_b, atol, rtol, rmse_tol): + logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item())) + logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item())) + try: + torch.testing.assert_close(a, b, atol=atol, rtol=rtol) + except Exception as e: + logging.debug(e) + + rmse = _rmse(a, b) + logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse)) + rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) + 
assert rmse < rmse_tol * rmse_range, ( + name_a + + " vs " + + name_b + + " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( + rmse, rmse_tol * rmse_range, rmse_tol, rmse_range + ) + ) + + @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.") @@ -1320,86 +1344,76 @@ def _rmse(a, b): @pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16) @pytest.mark.parametrize("input_layernorm", [True, False]) @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) -def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd): - os.environ["NVTE_FLASH_ATTN"] = "0" - os.environ["NVTE_FUSED_ATTN"] = "1" +@pytest.mark.parametrize("RoPE", [True, False]) +@pytest.mark.parametrize("is_training", [True, False]) +def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training): os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" - global _attention_backends - _attention_backends["backend_selection_requires_update"] = True + os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" config = model_configs_fp8_vs_f16[model] - os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" + if _flash_attn_3_plus and not is_training: + if RoPE: + pytest.skip("Flash Attention doesn't support FP8 MHA with RoPE.") + os.environ["NVTE_FLASH_ATTN"] = "1" + os.environ["NVTE_FUSED_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") + flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16( + dtype, config, True, qkv_format, input_layernorm, RoPE, is_training + ) + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16( - dtype, config, True, qkv_format, input_layernorm + dtype, config, True, qkv_format, input_layernorm, RoPE, is_training ) logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False") fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16( - dtype, config, False, qkv_format, input_layernorm - ) - - tols = dict(atol=5e-1, rtol=5e-1) - rmse_tol = 0.1 - fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16) - fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min( - fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item() + dtype, config, False, qkv_format, input_layernorm, RoPE, is_training ) + atol = 5e-1 + rtol = 5e-1 + rmse_tol = 0.15 logging.debug("========== {:^25s} ==========".format("forward output")) - logging.debug( - "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format( - fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item() + if _flash_attn_3_plus and not is_training: + _error( + flash_attn_fwd_fp8, + fused_attn_fwd_f16, + "flash_attn_fwd_fp8", + "fused_attn_fwd_f16", + atol, + rtol, + rmse_tol, ) + _error( + fused_attn_fwd_fp8, + fused_attn_fwd_f16, + "fused_attn_fwd_fp8", + "fused_attn_fwd_f16", + atol, + rtol, + rmse_tol, ) - logging.debug( - "fused_attn_fwd_f16 min {:.6f} max {:.6f}".format( - fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item() - ) - ) - logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse)) - 
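The new `_error` helper above, now shared by the FP8-vs-F16 comparisons in this file, logs `assert_close` mismatches at debug level and only hard-fails when the RMSE exceeds a fraction of the combined value range. A standalone sketch of that criterion on made-up tensors (function name is illustrative, not part of the diff):

import math
import torch

def rmse_within_range(a: torch.Tensor, b: torch.Tensor, rmse_tol: float) -> bool:
    rmse = math.sqrt(torch.mean((a - b) ** 2).item())
    value_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
    return rmse < rmse_tol * value_range  # same shape of check as _error above

a = torch.randn(1024)
b = a + 0.01 * torch.randn(1024)  # small perturbation standing in for FP8 rounding error
assert rmse_within_range(a, b, rmse_tol=0.1)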
try: - torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols) - except Exception as e: - logging.debug(e) - assert ( - fwd_rmse < rmse_tol * fwd_range - ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range - ) - for i in range(len(param_names[:1])): - bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i]) - bwd_range = max( - fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item() - ) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item()) - - logging.debug("========== {:^25s} ==========".format(param_names[i])) - logging.debug( - "fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format( - i, fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item() - ) - ) - logging.debug( - "fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format( - i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item() + if is_training: + for i in range(len(param_names[:1])): + logging.debug("========== {:^25s} ==========".format(param_names[i])) + _error( + fused_attn_bwd_fp8[i], + fused_attn_bwd_f16[i], + f"fused_attn_bwd_fp8[{i}]", + f"fused_attn_bwd_f16[{i}]", + atol, + rtol, + rmse_tol, ) - ) - logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse)) - try: - torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols) - except Exception as e: - logging.debug(e) - - assert ( - bwd_rmse < rmse_tol * bwd_range - ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range - ) -def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm): +def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training): reset_rng_states() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) @@ -1418,6 +1432,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: ) with fp8_model_init(enabled=fp8_mha): + rotary_pos_emb = None + if RoPE: + PE = RotaryPositionEmbedding(dim=config.head_dim_qk) + rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda") mha = MultiheadAttention( hidden_size=config.hidden_size, num_attention_heads=config.num_heads, @@ -1434,6 +1452,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: qkv_weight_interleaved=True, qkv_format=qkv_format, ).to(dtype=dtype, device="cuda") + if not is_training: + mha = mha.eval() seqlens_q = torch.full( [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" @@ -1464,7 +1484,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: tensor_shape = [dim_to_num[j] for j in layout.split("_")] tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda") hidden_states = tensor.view(*tensor.shape[:-2], -1) - hidden_states.requires_grad = True + if is_training: + hidden_states.requires_grad = True tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda") out_grad = tensor.view(*tensor.shape[:-2], -1) @@ -1475,8 +1496,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: checkpoint_core_attention=False, core_attention_bias_type=config.attn_bias_type, is_first_microbatch=None, + rotary_pos_emb=rotary_pos_emb, ) - out.backward(out_grad) + if is_training: + out.backward(out_grad) param_names = [] param_names.append("hidden_states.grad") @@ -1487,7 +1510,9 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: param_names.append(name + 
".grad") params.append(param) - return out, param_names, tuple(x.grad for x in params) + if is_training: + return out, param_names, tuple(x.grad for x in params) + return out, param_names, tuple(None for x in params) @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.") @@ -1497,7 +1522,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: @pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys()) @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16) @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) -def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd): +@pytest.mark.parametrize("is_training", [True, False]) +def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): config = model_configs_fp8_vs_f16[model] if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: @@ -1505,76 +1531,68 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd): os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" - global _attention_backends - _attention_backends["backend_selection_requires_update"] = True + if _flash_attn_3_plus and not is_training: + os.environ["NVTE_FLASH_ATTN"] = "1" + os.environ["NVTE_FUSED_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True") + flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( + dtype, config, True, qkv_layout, is_training + ) + + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + _attention_backends["backend_selection_requires_update"] = True logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True") - fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(dtype, config, True, qkv_layout) + fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( + dtype, config, True, qkv_layout, is_training + ) logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False") - fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(dtype, config, False, qkv_layout) + fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( + dtype, config, False, qkv_layout, is_training + ) - tols = dict(atol=5e-1, rtol=5e-2) + atol = 5e-1 + rtol = 5e-2 rmse_tol = 0.1 bwd_names = ["dq", "dk", "dv"] - fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16) - fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min( - fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item() - ) - logging.debug("========== {:^25s} ==========".format("forward output")) - logging.debug( - "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format( - fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item() - ) - ) - logging.debug( - "fused_attn_fwd_f16 min {:.6f} max {:.6f}".format( - fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item() + if _flash_attn_3_plus and not is_training: + _error( + flash_attn_fwd_fp8, + fused_attn_fwd_f16, + "flash_attn_fwd_fp8", + "fused_attn_fwd_f16", + atol, + rtol, + rmse_tol, ) + _error( + fused_attn_fwd_fp8, + fused_attn_fwd_f16, + "fused_attn_fwd_fp8", + "fused_attn_fwd_f16", + atol, + rtol, + rmse_tol, ) - logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse)) - try: - torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols) - except Exception as e: - logging.debug(e) - - assert ( - fwd_rmse < rmse_tol * fwd_range - ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - 
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range - ) - for i, _ in enumerate(fused_attn_bwd_f16): - bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i]) - bwd_range = max( - fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item() - ) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item()) - - logging.debug("========== {:^25s} ==========".format(bwd_names[i])) - logging.debug( - "fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format( - i, fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item() - ) - ) - logging.debug( - "fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format( - i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item() + if is_training: + for i, _ in enumerate(fused_attn_bwd_f16): + logging.debug("========== {:^25s} ==========".format(bwd_names[i])) + _error( + fused_attn_bwd_fp8[i], + fused_attn_bwd_f16[i], + f"fused_attn_bwd_fp8[{i}]", + f"fused_attn_bwd_f16[{i}]", + atol, + rtol, + rmse_tol, ) - ) - logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse)) - try: - torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols) - except Exception as e: - logging.debug(e) - - assert ( - bwd_rmse < rmse_tol * bwd_range - ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range - ) -def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout): +def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training): reset_rng_states() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() @@ -1607,6 +1625,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: attention_type="self", qkv_format=qkv_format, ).to(dtype=dtype, device="cuda") + if not is_training: + dpa = dpa.eval() seqlens_q = torch.full( [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" @@ -1680,9 +1700,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: core_attention_bias_type=config.attn_bias_type, is_first_microbatch=True, ) - out.backward(out_grad) + if is_training: + out.backward(out_grad) - return out, (inp[0].grad, inp[1].grad, inp[2].grad) + if is_training: + return out, (inp[0].grad, inp[1].grad, inp[2].grad) + return out, (None, None, None) model_configs_fp8 = { @@ -1726,58 +1749,26 @@ def test_custom_mha_fp8_vs_f16(dtype, model): fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention") unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention") - tols = dict(atol=5e-1, rtol=5e-1) + atol = 5e-1 + rtol = 5e-1 rmse_tol = 0.1 - fwd_rmse = _rmse(fused_attn_fwd_fp8, unfused_attn_fwd_f16) - fwd_range = max(fused_attn_fwd_fp8.max().item(), unfused_attn_fwd_f16.max().item()) - min( - fused_attn_fwd_fp8.min().item(), unfused_attn_fwd_f16.min().item() - ) - bwd_rmse = _rmse(fused_attn_bwd_fp8, unfused_attn_bwd_f16) - bwd_range = max(fused_attn_bwd_fp8.max().item(), unfused_attn_bwd_f16.max().item()) - min( - fused_attn_bwd_fp8.min().item(), unfused_attn_bwd_f16.min().item() - ) - - logging.debug( - "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format( - fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item() - ) - ) - logging.debug( - "unfused_attn_fwd_f16 min {:.6f} max {:.6f}".format( - unfused_attn_fwd_f16.min().item(), unfused_attn_fwd_f16.max().item() - ) - ) - logging.debug("fused_attn_fwd_fp8 vs unfused_attn_fwd_f16 RMSE: {:.6f}".format(fwd_rmse)) - try: - torch.testing.assert_close(fused_attn_fwd_fp8, 
unfused_attn_fwd_f16, **tols) - except Exception as e: - logging.debug(e) - - logging.debug( - "fused_attn_bwd_fp8 min {:.6f} max {:.6f}".format( - fused_attn_bwd_fp8.min().item(), fused_attn_bwd_fp8.max().item() - ) - ) - logging.debug( - "unfused_attn_bwd_f16 min {:.6f} max {:.6f}".format( - unfused_attn_bwd_f16.min().item(), unfused_attn_bwd_f16.max().item() - ) - ) - logging.debug("fused_attn_bwd_fp8 vs unfused_attn_bwd_f16 RMSE: {:.6f}".format(bwd_rmse)) - try: - torch.testing.assert_close(fused_attn_bwd_fp8, unfused_attn_bwd_f16, **tols) - except Exception as e: - logging.debug(e) - - assert ( - fwd_rmse < rmse_tol * fwd_range - ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range + _error( + fused_attn_fwd_fp8, + unfused_attn_fwd_f16, + "fused_attn_fwd_fp8", + "unfused_attn_fwd_f16", + atol, + rtol, + rmse_tol, ) - assert ( - bwd_rmse < rmse_tol * bwd_range - ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range + _error( + fused_attn_bwd_fp8, + unfused_attn_bwd_f16, + "fused_attn_bwd_fp8", + "unfused_attn_bwd_f16", + atol, + rtol, + rmse_tol, ) @@ -1791,7 +1782,6 @@ def _run_custom_mha_fp8(dtype, config, backend): os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" - global _attention_backends _attention_backends["backend_selection_requires_update"] = True inp = 0.0001 * torch.randint( @@ -1846,7 +1836,6 @@ def _run_ref_mha_f16(dtype, config, backend): os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" - global _attention_backends _attention_backends["backend_selection_requires_update"] = True inp = torch.load("qkv.pt").to(device="cuda") @@ -1996,12 +1985,18 @@ def forward( None, None, None, - fp8_meta["scaling_fwd"].scale_inv[META_QKV], - fp8_meta["scaling_fwd"].scale_inv[META_S], - fp8_meta["scaling_fwd"].scale[META_S], - fp8_meta["scaling_fwd"].scale[META_O], - fp8_meta["scaling_fwd"].amax_history[0][META_S], - fp8_meta["scaling_fwd"].amax_history[0][META_O], + fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv + META_QKV, # d_scale_qkv_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_s + META_S, # d_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_s + META_S, # q_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_o + META_O, # q_scale_o_offset + fp8_meta["scaling_fwd"].amax_history, # amax_s + META_S, # amax_s_offset + fp8_meta["scaling_fwd"].amax_history, # amax_o + META_O, # amax_o_offset attn_scale=None, dropout=p_dropout, fast_zero_fill=fast_zero_fill, diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 82875e2791..c1c18ffe47 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -22,10 +22,16 @@ "cp_1_2": ModelConfig( 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) ), # MHA - "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA - "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA + "cp_1_3": ModelConfig( + 2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512) + ), # MHA + "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA + "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA "cp_2_2": ModelConfig( 
- 2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + 2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + ), # GQA + "cp_2_3": ModelConfig( + 2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512) ), # GQA } @@ -45,31 +51,32 @@ def get_bash_arguments(**kwargs): @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"]) +@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"]) def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): config = model_configs_flash_attn[model] + if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1): + pytest.skip("CP implementation with KV P2P does not support sliding window yet!") if cp_comm_type == "all_gather" and qkv_format == "thd": - pytest.skip( - f"CP implementation with KV all-gather does not support {qkv_format} format yet!" - ) - if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type: - pytest.skip( - f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask" - " type yet!" - ) + pytest.skip("CP implementation with KV all-gather does not support THD format yet!") if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": + pytest.skip("CP implementation with KV all-gather does not support bias yet!") + if cp_comm_type == "a2a" and qkv_format == "thd": + pytest.skip("CP implementation with QKVO A2A does not support THD format yet!") + if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias": + pytest.skip("CP implementation with QKVO A2A does not support bias yet!") + if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): pytest.skip( - f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias" - " type yet!" - ) - if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1): - pytest.skip( - f"CP implementation with KV P2P does not support window size {config.window_size} yet!" + f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" + f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" 
) subprocess.run( get_bash_arguments( - dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention" + dtype=dtype, + model=model, + qkv_format=qkv_format, + kernel_backend="FlashAttention", + cp_comm_type=cp_comm_type, ), check=True, ) @@ -81,10 +88,16 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA "cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA - "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA - "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA - "cp_2_2": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA - "cp_2_3": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA + "cp_1_4": ModelConfig( + 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + ), # MHA + "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA + "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA + "cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA + "cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA + "cp_2_4": ModelConfig( + 2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + ), # GQA } @@ -93,37 +106,29 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"]) @pytest.mark.parametrize("model", model_configs_fused_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"]) +@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"]) def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): if qkv_format == "thd" and get_device_compute_capability() < (9, 0): - pytest.skip("THD format is only supported on sm90+.") + pytest.skip("THD format is only supported on sm90+!") if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0): - pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0") + pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!") + if dtype == "fp8" and get_device_compute_capability() < (9, 0): + pytest.skip("FP8 attention is only supported on sm90+!") config = model_configs_fused_attn[model] if qkv_format == "thd" and config.num_heads != config.num_gqa_groups: - pytest.skip(f"{qkv_format} format does not support QGA/MQA yet!") + pytest.skip("THD format does not support QGA/MQA yet!") if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": - pytest.skip(f"{qkv_format} format does not support {config.attn_bias_type} bias type yet!") - if cp_comm_type == "all_gather" and qkv_format == "thd": - pytest.skip( - f"CP implementation with KV all-gather does not support {qkv_format} format yet!" - ) - if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type: - pytest.skip( - f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask" - " type yet!" 
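The new skip conditions in both context-parallel test files require the query-head count and the GQA group count to divide evenly by `cp_size` (fixed at 2 here) whenever the A2A communication scheme is used, which is also why the GQA configs move from 1 to 2 KV groups; the constraint presumably reflects heads being sharded across CP ranks under A2A (an inference from the skip message, not stated in the diff). A minimal illustrative check:

def a2a_heads_ok(num_heads: int, num_gqa_groups: int, cp_size: int = 2) -> bool:
    # Hypothetical helper: both Q heads and KV groups must split evenly across CP ranks.
    return num_heads % cp_size == 0 and num_gqa_groups % cp_size == 0

assert a2a_heads_ok(12, 2)       # updated cp_2_* GQA configs
assert not a2a_heads_ok(12, 1)   # the old single-KV-group configs would be skipped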
- ) - if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": - pytest.skip( - f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias" - " type yet!" - ) - if config.window_size != (-1, 0) and config.window_size != (-1, -1): + pytest.skip("THD format does not support post_scale_bias yet!") + if qkv_format == "thd" and cp_comm_type == "all_gather": + pytest.skip("CP implementation with KV all-gather does not support THD format yet!") + if qkv_format == "thd" and cp_comm_type == "a2a": + pytest.skip("CP implementation with QKVO A2A does not support THD format yet!") + if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a": pytest.skip( - "Fused attention does not support sliding window attention + context parallelism yet!" + "Sliding window attention only can be supported with the implementation of QKVO A2A!" ) - if cp_comm_type == "all_gather" and dtype == "fp8": + if dtype == "fp8" and cp_comm_type == "all_gather": pytest.skip( "CP implementation with KV all-gather does not support FP8 + context parallelism yet!" ) @@ -131,10 +136,25 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): pytest.skip("FP8 attention cannot work with THD format yet!") if dtype == "fp8" and config.attn_bias_type != "no_bias": pytest.skip("FP8 attention cannot work with bias yet!") + if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1): + pytest.skip("FP8 attention cannot work with sliding window yet!") + if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": + pytest.skip("CP implementation with KV all-gather does not support bias yet!") + if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias": + pytest.skip("CP implementation with QKVO A2A does not support bias yet!") + if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): + pytest.skip( + f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" + f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" + ) subprocess.run( get_bash_arguments( - dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention" + dtype=dtype, + model=model, + qkv_format=qkv_format, + kernel_backend="FusedAttention", + cp_comm_type=cp_comm_type, ), check=True, ) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 0ea0319771..fd204f58c4 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -293,7 +293,7 @@ def test_transpose( with pytest.raises(AssertionError): torch.testing.assert_close(x_fp8_t, x, **tols) - # Caching test. + # Caching test assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching." x_fp8 += 0.5 x = x_fp8.from_float8() @@ -302,14 +302,13 @@ def test_transpose( torch.testing.assert_close(x_fp8_t, x_t, **tols) assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." - # Inplace update test. + # Inplace update test x_fp8 += 0.5 - assert x_fp8._transpose_invalid, "Transpose cache not invalidated properly." + assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." 
x = x_fp8.from_float8() - x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(fill_cache=True)) + x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8._transpose) x_t = x.transpose(0, 1) torch.testing.assert_close(x_fp8_t, x_t, **tols) - assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." def test_serialization( self, diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 3523e1cda5..e97dfe1efd 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -88,10 +88,7 @@ def make_reference_and_test_tensors( ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) test = ref.to(device=test_device, dtype=test_dtype) if test_is_fp8: - test = Float8Tensor.to_float8(test) - test._transpose = test._data.reshape(-1, test.size(-1)).transpose(0, 1) - test._transpose = test._transpose.contiguous() - test._transpose_invalid = False + test = Float8Tensor.to_float8(test, with_transpose_cache=True) elif test.data_ptr() == ref.data_ptr(): test = test.clone() ref.copy_(test) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 85cd4fc256..ad34b4996f 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -7,6 +7,7 @@ from typing import Dict, List, Optional import pytest import copy +import random import torch import torch.nn as nn @@ -30,6 +31,8 @@ TransformerLayer, LayerNorm, InferenceParams, + Fp8Padding, + Fp8Unpadding, ) from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm @@ -354,6 +357,40 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return (input > 0) * input * input +class TorchGroupedLinearWithPadding(nn.Module): + + def __init__( + self, num_gemms, in_features, out_features, bias, params_dtype, parallel_mode, fp8 + ) -> None: + super().__init__() + + self.padding = Fp8Padding(num_gemms) + self.linear_fn = GroupedLinear( + num_gemms, + in_features, + out_features, + bias=bias, + params_dtype=params_dtype, + parallel_mode=parallel_mode, + device="cuda", + ) + self.unpadding = Fp8Unpadding(num_gemms) + + self.fp8 = fp8 + + def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor: + if self.fp8: + orig_m_splits = m_splits + inp, m_splits = self.padding(inp, m_splits) + + out = self.linear_fn(inp, m_splits) + + if self.fp8: + out = self.unpadding(out, orig_m_splits) + + return out + + _supported_act = { "geglu": nn.GELU(approximate="tanh"), "gelu": nn.GELU(approximate="tanh"), @@ -1229,12 +1266,15 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False ) inp_hidden_states.retain_grad() - m = config.seq_len // 16 - dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() - dist.append(dist[-1]) # Manually add a zero - m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) - m_splits = m_splits * 16 - assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms + if num_gemms > 1: + m = config.seq_len // 16 + dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() + dist.append(dist[-1]) # Manually add a zero + m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) + m_splits = m_splits * 16 + assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms + else: + m_splits = torch.tensor([config.seq_len]) with fp8_autocast(enabled=fp8): if isinstance(block, GroupedLinear): @@ -1316,7 +1356,7 @@ def 
test_grouped_linear_accuracy( @pytest.mark.parametrize("parallel_mode", ["column", "row"]) def test_grouped_linear_accuracy_parallel_mode(parallel_mode): - """Split the tests to reduce CI time""" + """Split the tests to save CI time""" test_grouped_linear_accuracy( dtype=torch.float32, num_gemms=6, @@ -1328,6 +1368,170 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode): ) +def test_grouped_linear_accuracy_single_gemm(): + """Split the tests to save CI time""" + test_grouped_linear_accuracy( + dtype=torch.float32, + num_gemms=1, + bs=2, + model=list(model_configs.keys())[0], + fp8=True, + fp8_model_params=True, + ) + + +def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False): + + def _pad_tensor_for_fp8(hidden_states, tokens_per_expert): + """Padding tensor shapes to multiples of 16.""" + padded_tokens_per_expert = [ + (num_tokens + 15) // 16 * 16 for num_tokens in tokens_per_expert + ] + hidden_states = torch.split(hidden_states, tokens_per_expert) + padded_hidden_states = [] + for hidden_state, actual_num_tokens, padded_num_tokens in zip( + hidden_states, tokens_per_expert, padded_tokens_per_expert + ): + padded_hidden_states.append(hidden_state) + if padded_num_tokens > actual_num_tokens: + pad_tensor = torch.zeros( + padded_num_tokens - actual_num_tokens, + hidden_state.shape[1], + dtype=hidden_state.dtype, + device=hidden_state.device, + ) + padded_hidden_states.append(pad_tensor) + padded_hidden_states = torch.cat(padded_hidden_states, dim=0) + return padded_hidden_states, padded_tokens_per_expert + + def _unpad_tensor_for_fp8(padded_hidden_states, actual_tokens_per_expert, tokens_per_expert): + inputmats = torch.split( + padded_hidden_states.view(-1, padded_hidden_states.shape[-1]), tokens_per_expert + ) + hidden_states = torch.cat( + [ + grad_output_mat[: actual_tokens_per_expert[i]] + for i, grad_output_mat in enumerate(inputmats) + ], + dim=0, + ) + + return hidden_states + + def _generate_random_numbers(n, total_sum): + if n <= 0: + return [] + + # reset seed + random.seed(seed) + + breaks = sorted(random.sample(range(1, total_sum), n - 1)) + random_numbers = ( + [breaks[0]] + + [breaks[i] - breaks[i - 1] for i in range(1, n - 1)] + + [total_sum - breaks[-1]] + ) + + return random_numbers + + reset_rng_states() + if fp8: + FP8GlobalStateManager.reset() + + inp_hidden_states = torch.randn( + (config.seq_len * bs, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) + inp_hidden_states.retain_grad() + + m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs) + + with fp8_autocast(enabled=fp8): + if isinstance(block, TorchGroupedLinearWithPadding): + out = block(inp_hidden_states, m_splits) + else: + if fp8: + padded_inp_hidden_states, padding_m_splits = _pad_tensor_for_fp8( + inp_hidden_states, m_splits + ) + padded_inp_hidden_states = block(padded_inp_hidden_states, padding_m_splits) + out = _unpad_tensor_for_fp8(padded_inp_hidden_states, m_splits, padding_m_splits) + else: + out = block(inp_hidden_states, m_splits) + + loss = out.sum() + loss.backward() + + torch.cuda.synchronize() + outputs = [out, inp_hidden_states.grad] + for p in block.parameters(): + if p.requires_grad: + outputs.append(p.grad) + return outputs + + +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("num_gemms", [3, 6]) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("fp8", [True]) 
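For reference, the token-count rounding performed by `_pad_tensor_for_fp8` above (the FP8 grouped GEMMs in these tests expect each group's row count to be a multiple of 16, with zero-filled rows appended up to the padded length) boils down to the following sketch; the helper name is made up for illustration:

def pad_to_multiple_of_16(tokens_per_expert):
    # Mirrors the rounding in _pad_tensor_for_fp8 above.
    return [(n + 15) // 16 * 16 for n in tokens_per_expert]

assert pad_to_multiple_of_16([5, 16, 30]) == [16, 16, 32]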
+@pytest.mark.parametrize("fp8_model_params", all_boolean) +def test_padding_grouped_linear_accuracy( + dtype, num_gemms, bs, model, fp8, fp8_model_params, parallel_mode=None +): + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + + config = model_configs[model] + if config.seq_len % 16 != 0 and fp8: + pytest.skip("FP8 requires sequence length to be divisible by 16.") + + with fp8_model_init(enabled=fp8 and fp8_model_params): + grouped_linear = TorchGroupedLinearWithPadding( + num_gemms, + config.hidden_size, + 4 * config.hidden_size, + bias=False, + params_dtype=dtype, + parallel_mode=parallel_mode, + fp8=fp8, + ).eval() + + with fp8_model_init(enabled=fp8 and fp8_model_params): + ref_grouped_linear = GroupedLinear( + num_gemms, + config.hidden_size, + 4 * config.hidden_size, + bias=False, + params_dtype=dtype, + parallel_mode=parallel_mode, + device="cuda", + ).eval() + + # Share params + with torch.no_grad(): + inner_grouped_linear = grouped_linear.linear_fn + for i in range(num_gemms): + setattr( + ref_grouped_linear, + f"weight{i}", + Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()), + ) + + outputs = _test_padding_grouped_linear_accuracy( + grouped_linear, num_gemms, bs, dtype, config, fp8 + ) + outputs_ref = _test_padding_grouped_linear_accuracy( + ref_grouped_linear, num_gemms, bs, dtype, config, fp8 + ) + + # Shoule be bit-wise match + for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + + def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph): reset_rng_states() @@ -1845,7 +2049,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate): fp8_grouped_gemm( A_fp8, - scale_inv, + [scale_inv], 0, # A_offset tex.DType.kFloat8E4M3, B_fp8, diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py new file mode 100644 index 0000000000..ed25b96955 --- /dev/null +++ b/tests/pytorch/test_permutation.py @@ -0,0 +1,513 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch +import pytest +from typing import Dict, List + +from transformer_engine.pytorch import moe_permute as te_permute, moe_unpermute as te_unpermute +from transformer_engine.pytorch.utils import is_bf16_compatible +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.float8_tensor import Float8Tensor +import transformer_engine_torch as tex + + +seed = 1234 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) + + +def pytorch_permute(tokens, indices, num_out_tokens: int = None): + """ + Permute the tokens based on the indices. Token with the same index will be grouped together. + The input indices shape is [tokens, top_k], it indicates which experts were selected by each token separately. + + Args: + tokens: torch.Tensor + The input token tensor. + indices: torch.Tensor + The token to expert indices tensor, should have a shape of [num_tokens] or [num_tokens, topk]. + num_out_tokens: int, optional + The effective output token count, when enabling the capacity factor, should equal the number of tokens not dropped. + By default, set to None, meaning no tokens are dropped. + + Returns: + torch.Tensor: + The permuted tensor. + torch.Tensor: + The sorted_indices corresponding permuted tensor. 
+ """ + if indices.dim() == 1: + topk = 1 + else: + topk = indices.size(1) + flatten_indices = indices.view(-1) + sorted_indices = torch.argsort(flatten_indices, stable=True) + num_out_tokens = num_out_tokens if num_out_tokens is not None else flatten_indices.size(0) + + permuted_tokens = tokens.index_select(0, sorted_indices[:num_out_tokens] // topk) + return permuted_tokens, sorted_indices + + +def pytorch_unpermute( + permuted_tokens: torch.Tensor, + sorted_indices: torch.Tensor, + probs: torch.Tensor = None, +): + """ + Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the tokens with their + corresponding probabilities. + + Args: + permuted_tokens: torch.Tensor + The tensor of permuted tokens to be unpermuted. + sorted_indices: torch.Tensor + The tensor of sorted indices used to unpermute the tokens. + probs: torch.Tensor, optional + The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will + be merged with their respective probabilities. + + Returns: + torch.Tensor: + The unpermuted tokens, optionally merged with probabilities. + """ + + if probs is not None: + # Unpermute and merge the tokens with their probabilities + num_unpermuted_tokens = probs.numel() + topk = probs.size(1) + else: + # Unpermute the tokens without merge + num_unpermuted_tokens = sorted_indices.size(0) + topk = 1 + unpermuted_tokens = torch.zeros( + [num_unpermuted_tokens, permuted_tokens.shape[-1]], + dtype=permuted_tokens.dtype, + device=permuted_tokens.device, + ) + + unpermuted_tokens.index_copy_(0, sorted_indices[: permuted_tokens.size(0)], permuted_tokens) + unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1)) + if probs is not None: + unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1) + unpermuted_tokens = unpermuted_tokens.sum(dim=1) + return unpermuted_tokens + + +def dtype_tols(te_dtype: tex.DType) -> Dict[str, float]: + """Estimated tolerances for a datatype + + Based on tolerances for torch.testing.assert_close. 
+ + """ + if te_dtype == tex.DType.kFloat32: + return dict(rtol=1.0e-6, atol=1.0e-6) + if te_dtype == tex.DType.kFloat16: + return dict(rtol=3.0e-3, atol=1.0e-5) + if te_dtype == tex.DType.kBFloat16: + return dict(rtol=2.0e-2, atol=1.0e-5) + if te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3: + return dict(rtol=2.0e-1, atol=1.0e-1) + raise ValueError(f"Unsuppored dtype ({te_dtype})") + + +def _test_permutation( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, + with_probs, + BENCHMARK=False, +): + if not with_probs and topK > 1: + pytest.skip("Only permutations with topK=1 and without probabilities are supported.") + + if topK > num_expert: + pytest.skip("topK should be smaller than the number of experts.") + + if num_out_tokens == None: + num_out_tokens = num_tokens * topK + + print( + f"token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" + ) + + fp8 = False + # Convert TE dtypes to PyTorch dtypes + if te_dtype == tex.DType.kFloat32: + dtype = torch.float32 + elif te_dtype == tex.DType.kFloat16: + dtype = torch.float16 + elif te_dtype == tex.DType.kBFloat16: + dtype = torch.bfloat16 + elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3): + dtype = torch.uint8 + fp8 = True + else: + pytest.skip("Invalid dtype.") + + if fp8: + permute_fwd_input = torch.rand( + size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + permute_bwd_input = torch.rand( + size=(num_out_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + unpermute_bwd_input = torch.rand( + size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + + permute_fwd_input = Float8Tensor.to_float8( + permute_fwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + ) + permute_bwd_input = Float8Tensor.to_float8( + permute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + ) + unpermute_bwd_input = Float8Tensor.to_float8( + unpermute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + ) + + pytorch_permute_fwd_input = permute_fwd_input.from_float8(torch.float16) + pytorch_permute_bwd_input = permute_bwd_input.from_float8(torch.float16) + pytorch_unpermute_bwd_input = unpermute_bwd_input.from_float8(torch.float16) + else: + pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() + pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + + pytorch_permute_fwd_input.requires_grad_(True) + + if num_tokens > 0: + indices = torch.stack([torch.randperm(num_expert)[:topK] for _ in range(num_tokens)]) + else: + indices = torch.empty((num_tokens, topK)) + indices = indices.to(torch.int32).cuda() + + probs = None + if with_probs: + probs = torch.rand(num_tokens, topK).cuda() + row_sums = probs.sum(dim=1, keepdim=True) + probs = probs / row_sums + probs.requires_grad_(True) + + ################################################################################################################################### + # + # PyTorch Permutation + # + ################################################################################################################################### + pytorch_permute_output, sorted_indices = pytorch_permute( + pytorch_permute_fwd_input, indices, num_out_tokens + ) + pytorch_permute_output.backward(pytorch_permute_bwd_input, retain_graph=True) + + pytorch_unpermute_fwd_input = 
pytorch_permute_output.detach() + pytorch_unpermute_fwd_input.requires_grad_(True) + + pytorch_unpermute_output = pytorch_unpermute( + pytorch_unpermute_fwd_input, sorted_indices, probs=probs + ) + pytorch_unpermute_output.backward(pytorch_unpermute_bwd_input, retain_graph=True) + + ################################################################################################################################### + # + # TE Permutation + # + ################################################################################################################################### + te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach() + te_permute_fwd_input.requires_grad_(True) + te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach() + + te_permute_output, row_id_map = te_permute(te_permute_fwd_input, indices, num_out_tokens) + te_permute_output.backward(te_permute_bwd_input, retain_graph=True) + + te_probs = None + if with_probs: + te_probs = probs.detach() + te_probs.requires_grad_(True) + te_unpermute_fwd_input = te_permute_output.detach() + te_unpermute_fwd_input.requires_grad_(True) + te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() + + te_unpermute_output = te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs) + te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True) + + ################################################################################################################################### + # + # Results Check + # + ################################################################################################################################### + tols = dtype_tols(te_dtype) + + if fp8: + te_permute_output_ = te_permute_output.from_float8(torch.float32) + te_permute_fwd_input_grad = te_permute_fwd_input.grad.from_float8(torch.float32) + te_unpermute_output_ = te_unpermute_output.from_float8(torch.float32) + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.from_float8(torch.float32) + else: + te_permute_output_ = te_permute_output.float() + te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() + te_unpermute_output_ = te_unpermute_output.float() + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float() + + torch.testing.assert_close( + pytorch_permute_output.float(), + te_permute_output_, + msg=f"Mismatch in te_permute fwd", + ) + torch.testing.assert_close( + pytorch_permute_fwd_input.grad.float(), + te_permute_fwd_input_grad, + msg=f"Mismatch in te_permute bwd", + **tols, + ) + torch.testing.assert_close( + pytorch_unpermute_output.float(), + te_unpermute_output_, + msg=f"Mismatch in te_unpermute fwd", + **tols, + ) + torch.testing.assert_close( + pytorch_unpermute_fwd_input.grad.float(), + te_unpermute_fwd_input_grad, + msg=f"Mismatch in te_unpermute bwd", + **tols, + ) + if with_probs: + torch.testing.assert_close( + probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols + ) + + if not pytorch_permute_fwd_input.numel(): + print("Empty pytorch_permute_fwd_input activation test passed.") + return + + ################################################################################################################################### + # + # Benchmark + # + ################################################################################################################################### + def backward_wrapper( + act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False + ): 
+ # Set forward_input.grad to None to avoid grad accumulation. + if accumulate_grad == False: + for i in forward_input: + i.grad = None + return act.backward(backward_input, retain_graph=retain_graph) + + if BENCHMARK: + t1 = perf_test_cuda_kernel( + lambda: pytorch_permute(pytorch_permute_fwd_input, indices, num_out_tokens) + ) + t2 = perf_test_cuda_kernel( + lambda: te_permute(te_permute_fwd_input, indices, num_out_tokens) + ) + print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: backward_wrapper( + pytorch_permute_output, + pytorch_permute_bwd_input, + forward_input=[pytorch_permute_fwd_input], + retain_graph=True, + accumulate_grad=False, + ) + ) + t2 = perf_test_cuda_kernel( + lambda: backward_wrapper( + te_permute_output, + te_permute_bwd_input, + forward_input=[te_permute_fwd_input], + retain_graph=True, + accumulate_grad=False, + ) + ) + print(f"permute\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: pytorch_unpermute(pytorch_unpermute_fwd_input, sorted_indices, probs=probs) + ) + t2 = perf_test_cuda_kernel( + lambda: te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs) + ) + print(f"unpermute\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: backward_wrapper( + pytorch_unpermute_output, + pytorch_unpermute_bwd_input, + forward_input=( + [pytorch_unpermute_fwd_input, probs] + if with_probs + else [pytorch_unpermute_fwd_input] + ), + retain_graph=True, + accumulate_grad=False, + ) + ) + t2 = perf_test_cuda_kernel( + lambda: backward_wrapper( + te_unpermute_output, + te_unpermute_bwd_input, + forward_input=( + [te_unpermute_fwd_input, te_probs] if with_probs else [te_unpermute_fwd_input] + ), + retain_graph=True, + accumulate_grad=False, + ) + ) + print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + +def perf_test_cuda_kernel(cuda_kernel_fn): + if torch.cuda.is_available(): + # create CUDA event + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # warmup + for _ in range(50): + cuda_kernel_fn() + + start_event.record() + for _ in range(100): + cuda_kernel_fn() + end_event.record() + torch.cuda.synchronize() + + elapsed_time_ms = start_event.elapsed_time(end_event) + return elapsed_time_ms / 100 + else: + pytest.skip("CUDA is not available.") + + +# TE tensor dtypes +_te_dtypes: List[tex.DType] = [tex.DType.kFloat32, tex.DType.kFloat16] +if is_bf16_compatible(): + _te_dtypes.append(tex.DType.kBFloat16) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_tokens", [4096]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +def test_permutation( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, +): + with_probs = True + BENCHMARK = False + + _test_permutation( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +# Only run FP8 tests on H100. 
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) +@pytest.mark.parametrize("num_tokens", [2048]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +def test_permutation_fp8( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, +): + with_probs = True + BENCHMARK = False + + _test_permutation( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_tokens", [4096]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +def test_permutation_topk1_no_probs( + te_dtype, + num_tokens, + num_expert, + hidden_size, +): + topK = 1 + num_out_tokens = None + with_probs = False + BENCHMARK = False + + _test_permutation( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +def test_permutation_single_case(): + print("GPU:", torch.cuda.get_device_name(0)) + + # te_dtype = tex.DType.kFloat32 + # te_dtype = tex.DType.kFloat16 + # te_dtype = tex.DType.kBFloat16 + te_dtype = tex.DType.kFloat8E5M2 + # te_dtype = tex.DType.kFloat8E4M3 + + num_tokens = 10 + num_expert = 4 + hidden_size = 16 + topK = 2 + num_out_tokens = num_tokens * topK - 1 + with_probs = True + Benchmark = True + + _test_permutation( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=Benchmark, + ) + + +if __name__ == "__main__": + test_permutation_single_case() diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 58bd4f828c..647d2c474d 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -48,6 +48,7 @@ set(transformer_engine_SOURCES) list(APPEND transformer_engine_SOURCES pycudnn.cpp transformer_engine.cpp + common.cu transpose/cast_transpose.cu transpose/transpose.cu transpose/cast_transpose_fusion.cu @@ -65,10 +66,12 @@ list(APPEND transformer_engine_SOURCES layer_norm/ln_api.cpp layer_norm/ln_bwd_semi_cuda_kernel.cu layer_norm/ln_fwd_cuda_kernel.cu + permutation/permutation.cu rmsnorm/rmsnorm_api.cpp rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu rmsnorm/rmsnorm_fwd_cuda_kernel.cu util/cast.cu + util/padding.cu util/cuda_driver.cpp util/cuda_runtime.cpp util/rtc.cpp diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 46cfa9176a..4bcd1f8e27 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -47,20 +47,20 @@ def _get_sys_extension(): def _load_cudnn(): """Load CUDNN shared library.""" - + # Attempt to locate cuDNN in Python dist-packages lib_path = glob.glob( os.path.join( sysconfig.get_path("purelib"), f"nvidia/cudnn/lib/libcudnn.{_get_sys_extension()}.*[0-9]", ) ) - if lib_path: assert ( len(lib_path) == 1 ), f"Found {len(lib_path)} 
libcudnn.{_get_sys_extension()}.x in nvidia-cudnn-cuXX." return ctypes.CDLL(lib_path[0], mode=ctypes.RTLD_GLOBAL) + # Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH") if cudnn_home: libs = glob.glob(f"{cudnn_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True) @@ -68,13 +68,14 @@ def _load_cudnn(): if libs: return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) - cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") - if cuda_home: - libs = glob.glob(f"{cuda_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True) - libs.sort(reverse=True, key=os.path.basename) - if libs: - return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + # Attempt to locate cuDNN in CUDA_HOME, CUDA_PATH or /usr/local/cuda + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" + libs = glob.glob(f"{cuda_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True) + libs.sort(reverse=True, key=os.path.basename) + if libs: + return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) @@ -91,14 +92,15 @@ def _load_library(): def _load_nvrtc(): """Load NVRTC shared library.""" - cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") - if cuda_home: - libs = glob.glob(f"{cuda_home}/**/libnvrtc.{_get_sys_extension()}*", recursive=True) - libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs)) - libs.sort(reverse=True, key=os.path.basename) - if libs: - return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) - + # Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" + libs = glob.glob(f"{cuda_home}/**/libnvrtc.{_get_sys_extension()}*", recursive=True) + libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs)) + libs.sort(reverse=True, key=os.path.basename) + if libs: + return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + + # Attempt to locate NVRTC via ldconfig libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True) libs = libs.decode("utf-8").split("\n") sos = [] @@ -109,6 +111,8 @@ def _load_nvrtc(): sos.append(lib.split(">")[1].strip()) if sos: return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL) + + # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index fc93705dff..6184e235bd 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -27,7 +27,8 @@ void act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) { reinterpret_cast(input.data.dptr), reinterpret_cast(output->data.dptr), reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), tot_elts, {}, + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), tot_elts, {}, stream);); // NOLINT(*) ); // NOLINT(*) } @@ -50,7 +51,8 @@ void dact_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream reinterpret_cast(input.data.dptr), reinterpret_cast(output->data.dptr), 
reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), tot_elts, {}, + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), tot_elts, {}, stream);); // NOLINT(*) ); // NOLINT(*) } @@ -74,7 +76,8 @@ void gated_act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) { reinterpret_cast(input.data.dptr), reinterpret_cast(output->data.dptr), reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), output->data.shape[0], + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), output->data.shape[0], output->data.shape[1], {}, stream);); // NOLINT(*) ); // NOLINT(*) diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu new file mode 100644 index 0000000000..4e95fc24de --- /dev/null +++ b/transformer_engine/common/common.cu @@ -0,0 +1,32 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include "./common.h" +#include "./utils.cuh" + +namespace transformer_engine { + +namespace { + +__global__ void __launch_bounds__(1) + update_tensor_scale_inv_kernel(const float* __restrict__ scale_ptr, + float* __restrict__ scale_inv_ptr) { + const float scale = scale_ptr == nullptr ? 1 : *scale_ptr; + reciprocal(scale_inv_ptr, scale); +} + +} // namespace + +void update_tensor_scale_inv(Tensor* t, cudaStream_t stream) { + if (t->scale_inv.dptr != nullptr) { + update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>( + reinterpret_cast(t->scale.dptr), reinterpret_cast(t->scale_inv.dptr)); + } +} + +} // namespace transformer_engine diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 42b529f388..7e72e1b031 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -262,6 +262,13 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt bool is_fp8_dtype(const DType t); +/*! \brief Update a tensor's FP8 scale-inverse + * + * The FP8 scale-inverse (dequantization scaling factor) is updated + * with the reciprocal of the FP8 scale (quantization scaling factor). 
+ */
+void update_tensor_scale_inv(Tensor *t, cudaStream_t stream);
+
 #define NVTE_API_CALL(api_name) \
   transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name);
diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu
index c9b57752e2..593ec086d7 100644
--- a/transformer_engine/common/gemm/cublaslt_gemm.cu
+++ b/transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -255,7 +255,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
             "Unable to find suitable cuBLAS GEMM algorithm");
   NVTE_CHECK_CUBLAS(status);
 
-  if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms");
+  if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms");
 
   // D = alpha * (A * B) + beta * C
   NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc,
@@ -269,6 +269,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                                    workspace,     /* workspace */
                                    workspaceSize, stream)); /* stream */
 
+  // Update FP8 scale-inv in output tensor
+  if (is_fp8_dtype(outputD->data.dtype)) {
+    update_tensor_scale_inv(outputD, stream);
+  }
+
   NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
   NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
   NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
diff --git a/transformer_engine/common/include/transformer_engine/padding.h b/transformer_engine/common/include/transformer_engine/padding.h
new file mode 100644
index 0000000000..a419b38234
--- /dev/null
+++ b/transformer_engine/common/include/transformer_engine/padding.h
@@ -0,0 +1,51 @@
+/*************************************************************************
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+/*! \file padding.h
+ *  \brief Functions handling padding.
+ */
+
+#ifndef TRANSFORMER_ENGINE_PADDING_H_
+#define TRANSFORMER_ENGINE_PADDING_H_
+
+#include "transformer_engine.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/*! \brief Padding multiple tensors.
+ *
+ *  NOTE: Only bottom padding is supported.
+ *
+ *  For example, a 3x3 matrix is padded to a 4x3 matrix.
+ *
+ *  source
+ *  | 1 | 2 | 3 |
+ *  | 4 | 5 | 6 |
+ *  | 7 | 8 | 9 |
+ *
+ *  destination
+ *  | 1 | 2 | 3 |
+ *  | 4 | 5 | 6 |
+ *  | 7 | 8 | 9 |
+ *  | 0 | 0 | 0 |
+ *
+ *  \param[in]     num_tensors           Number of tensors.
+ *  \param[in]     input_list            List of 2D input tensors.
+ *  \param[in,out] output_list           List of padded output tensors. Row counts are given by
+ *                                       padded_num_rows_list; column counts match the
+ *                                       corresponding tensors in input_list.
+ *  \param[in]     padded_num_rows_list  List of padded num rows corresponding to input tensors.
+ *  \param[in]     stream                CUDA stream used for the operation.
+ */
+void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list,
+                        const int* padded_num_rows_list, cudaStream_t stream);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif
+
+#endif  // TRANSFORMER_ENGINE_PADDING_H_
diff --git a/transformer_engine/common/include/transformer_engine/permutation.h b/transformer_engine/common/include/transformer_engine/permutation.h
new file mode 100644
index 0000000000..c6263bf87e
--- /dev/null
+++ b/transformer_engine/common/include/transformer_engine/permutation.h
@@ -0,0 +1,21 @@
+/*************************************************************************
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_PERMUTATION_H_ +#define TRANSFORMER_ENGINE_PERMUTATION_H_ + +#include "transformer_engine.h" + +void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor sorted_row_id, + NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad, + const NVTETensor input_fwd, const int num_rows, const int topK, + const int num_cols, const int num_out_tokens, cudaStream_t stream = nullptr); + +void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map, + const NVTETensor prob, const int num_rows, const int topK, const int num_cols, + cudaStream_t stream = nullptr); + +#endif // TRANSFORMER_ENGINE_PERMUTATION_H_ diff --git a/transformer_engine/common/layer_norm/ln.h b/transformer_engine/common/layer_norm/ln.h index 45839ed75b..13543a10aa 100644 --- a/transformer_engine/common/layer_norm/ln.h +++ b/transformer_engine/common/layer_norm/ln.h @@ -89,6 +89,9 @@ struct FwdParams : public ParamsBase { // AMax output void *amax; + // Inverse of scaling factor + void *scale_inv; + // Whether to compute scale and amax bool fp8_out; }; diff --git a/transformer_engine/common/layer_norm/ln_api.cpp b/transformer_engine/common/layer_norm/ln_api.cpp index 115422e94e..8a40450e59 100644 --- a/transformer_engine/common/layer_norm/ln_api.cpp +++ b/transformer_engine/common/layer_norm/ln_api.cpp @@ -196,6 +196,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size params.epsilon = epsilon; params.amax = z->amax.dptr; params.scale = z->scale.dptr; + params.scale_inv = z->scale_inv.dptr; params.fp8_out = fp8_out; params.zero_centered_gamma = zero_centered_gamma; diff --git a/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh b/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh index 9fe4c16373..bd3741d1d1 100644 --- a/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh +++ b/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh @@ -132,10 +132,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( } } if (params.fp8_out) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0 && threadIdx.y == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); + // Reduce amax over block + if (params.amax != nullptr) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { + reciprocal(reinterpret_cast(params.scale_inv), scale); } } } @@ -291,10 +299,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne // Finalize fp8 factors if (params.fp8_out) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); + // Reduce amax over block + if (params.amax != nullptr) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { + reciprocal(reinterpret_cast(params.scale_inv), scale); } } } diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu new file 
mode 100644
index 0000000000..2b894fbfdc
--- /dev/null
+++ b/transformer_engine/common/permutation/permutation.cu
@@ -0,0 +1,369 @@
+/*************************************************************************
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include
+
+#include "../common.h"
+
+static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id_map,
+                                           const int num_rows, const int topK,
+                                           const int num_out_tokens) {
+  // Each thread handles one entry of sorted_row_id, i.e. one (source token, topK slot) pair
+  // row_id_map[topK][num_rows]
+  const int bid = blockIdx.x;
+  const int tid = threadIdx.x;
+  const int idx = bid * blockDim.x + tid;
+
+  if (idx >= num_rows * topK) return;
+
+  int source_row = sorted_row_id[idx];
+  int source_token_id = source_row / topK;
+  int source_topK_id = source_row % topK;
+
+  if (idx >= num_out_tokens) {
+    // Set the indices of dropped tokens to -1
+    row_id_map[source_topK_id * num_rows + source_token_id] = -1;
+  } else {
+    // Create a row id map for subsequent unpermute operation
+    row_id_map[source_topK_id * num_rows + source_token_id] = idx;
+  }
+}
+
+template
+__global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const int *row_id_map,
+                                     const float *prob, const int num_rows, const int topK,
+                                     const int num_cols) {
+  extern __shared__ int8_t s_mem[];
+  TCompute *s_prob = reinterpret_cast(s_mem);
+
+  // Each block corresponds to one dest token
+  const int source_token = blockIdx.x;
+  const int tid = threadIdx.x;
+
+  if (hasProb) {
+    for (int i = tid; i < topK; i += blockDim.x * blockDim.y) {
+      // Load all the topK probs related to the source row into smem
+      s_prob[i] = TCompute(prob[source_token * topK + i]);
+    }
+    __syncthreads();
+  }
+
+  // Register buffers for vector type (float4) memory access
+  float4 frag_load_store;
+  T *frag_load_store_ptr = reinterpret_cast(&frag_load_store);
+
+  // Number of elements in frag_load_store
+  static constexpr int kElementsPerAccess = 16 / sizeof(T);
+
+  // Traverse along the hidden dimension
+  for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) {
+    TCompute frag_elem[kElementsPerAccess];
+    TCompute frag_sum[kElementsPerAccess];
+
+    int source_row = row_id_map[source_token];
+
+    // source_row == -1 represents a dropped token
+    if (source_row != -1) {
+      const T *source_row_ptr = input + source_row * num_cols;
+
+      frag_load_store = __ldlu(reinterpret_cast(source_row_ptr + i));
+
+      for (int e = 0; e < kElementsPerAccess; e++) {
+        frag_sum[e] = TCompute(frag_load_store_ptr[e]);
+      }
+
+      if (hasProb) {
+        for (int e = 0; e < kElementsPerAccess; e++) {
+          frag_sum[e] = frag_sum[e] * s_prob[0];
+        }
+      }
+    } else {
+      for (int e = 0; e < kElementsPerAccess; e++) {
+        frag_sum[e] = TCompute(0.0f);
+      }
+    }
+
+    for (int k = 1; k < topK; k++) {
+      source_row = row_id_map[k * num_rows + source_token];
+
+      if (source_row == -1) continue;
+
+      const T *source_row_ptr = input + source_row * num_cols;
+
+      frag_load_store = __ldlu(reinterpret_cast(source_row_ptr + i));
+
+      for (int e = 0; e < kElementsPerAccess; e++) {
+        frag_elem[e] = TCompute(frag_load_store_ptr[e]);
+      }
+
+      if (hasProb) {
+        for (int e = 0; e < kElementsPerAccess; e++) {
+          frag_elem[e] = frag_elem[e] * s_prob[k];
+        }
+      }
+
+      for (int e = 0; e < kElementsPerAccess; e++) {
+        frag_sum[e] = frag_sum[e] + frag_elem[e];
+      }
+    }
+
+    T *dest_row_ptr = unpermuted_output + source_token * num_cols;
+
+    for (int e = 0; e < kElementsPerAccess; e++) {
+      if constexpr ((std::is_same_v || std::is_same_v) &&
+                    (!hasProb)) {
+        frag_sum[e] = frag_sum[e] / TCompute(topK);
+      }
+      frag_load_store_ptr[e] = T(frag_sum[e]);
+    }
+
+    *reinterpret_cast(dest_row_ptr + i) = frag_load_store;
+  }
+}
+
+template
+__global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *act_grad,
+                                   const float *prob, float *prob_grad, const int *row_id_map,
+                                   const int num_rows, const int topK, const int num_cols) {
+  extern __shared__ int8_t s_mem[];
+  TCompute *s_prob = reinterpret_cast(s_mem);
+
+  // Each block corresponds to one source token
+  const int source_token = blockIdx.x;
+  const int tid = threadIdx.x;
+
+  if (hasProb) {
+    for (int i = tid; i < topK; i += blockDim.x) {
+      // Load all the topK probs related to the source row into smem
+      s_prob[i] = TCompute(prob[source_token * topK + i]);
+    }
+    __syncthreads();
+  }
+
+  // Accumulators for the calculation of prob_grad
+  float accum[topKTile] = {0.0f};
+
+  // Register buffers for vector type (float4) memory access
+  float4 frag_load_store;
+  T *frag_load_store_ptr = reinterpret_cast(&frag_load_store);
+
+  // Number of elements in frag_load_store
+  static constexpr int kElementsPerAccess = 16 / sizeof(T);
+
+  // The starting address of each source row
+  const T *source_row_ptr = input_bwd + source_token * num_cols;
+
+  // Traverse along the hidden dimension
+  for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) {
+    TCompute frag_src[kElementsPerAccess];
+
+    frag_load_store = __ldlu(reinterpret_cast(source_row_ptr + i));
+
+    for (int e = 0; e < kElementsPerAccess; e++) frag_src[e] = TCompute(frag_load_store_ptr[e]);
+
+    int index = source_token;
+
+    // Process each row in the corresponding topK rows
+    for (int k = 0; k < topKTile; k++) {
+      if (k == topK) break;
+
+      int dest_row = row_id_map[index];
+      index += num_rows;
+
+      if (dest_row != -1) {
+        if (hasProb) {
+          // Calculate act_grad in unpermute bwd
+          for (int e = 0; e < kElementsPerAccess; e++)
+            frag_load_store_ptr[e] = T(frag_src[e] * s_prob[k]);
+        } else {
+          // permute fwd
+          for (int e = 0; e < kElementsPerAccess; e++) frag_load_store_ptr[e] = T(frag_src[e]);
+        }
+
+        T *dest_row_ptr = act_grad + dest_row * num_cols;
+        *reinterpret_cast(dest_row_ptr + i) = frag_load_store;
+
+        if (hasProb) {
+          // Inner product calculation for prob_grad in unpermute bwd
+          const T *input_fwd_ptr = input_fwd + dest_row * num_cols;
+
+          frag_load_store = __ldlu(reinterpret_cast(input_fwd_ptr + i));
+
+          TCompute frag_input_fwd[kElementsPerAccess];
+          for (int e = 0; e < kElementsPerAccess; e++)
+            frag_input_fwd[e] = TCompute(frag_load_store_ptr[e]);
+
+          for (int e = 0; e < kElementsPerAccess; e++) {
+            accum[k] += static_cast(frag_src[e] * frag_input_fwd[e]);
+          }
+        }
+      }
+    }
+  }
+
+  if (hasProb) {
+    for (int k = 0; k < topKTile; k++) {
+      if (k == topK) break;
+      // Warp-level reduction
+      for (int mask = 16; mask > 0; mask /= 2) {
+        accum[k] = accum[k] + __shfl_xor_sync(0xffffffff, accum[k], mask, 32);
+      }
+    }
+
+    if (tid == 0) {
+      for (int k = 0; k < topKTile; k++) {
+        if (k == topK) break;
+        prob_grad[source_token * topK + k] = accum[k];
+      }
+    }
+  }
+}
+
+template
+void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, int *row_id_map,
+                           const float *prob, float *prob_grad, const T *input_fwd,
+                           const int num_rows, const int topK, const int num_cols,
+                           const int num_out_tokens, cudaStream_t stream) {
+  using TCompute = typename
std::conditional<(std::is_same::value || + std::is_same::value), + half, T>::type; + + static constexpr int kElementsPerAccess = 16 / sizeof(T); + + if (input_fwd == nullptr) { + // moe_permute_fwd + + int threads = 64; + int blocks = (num_rows * topK + threads - 1) / threads; + + moe_permute_row_map<<>>(sorted_row_id, row_id_map, num_rows, topK, + num_out_tokens); + + blocks = num_rows; + threads = std::min(num_cols / kElementsPerAccess, 1024); + moe_permute_kernel<<>>( + input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); + } else { + // moe_unpermute_bwd + + int threads = 32; + int blocks = num_rows; + + if (prob == nullptr) { + // moe_unpermute_bwd without probs + + moe_permute_kernel<<>>( + input, input_fwd, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); + } else { + // moe_unpermute_bwd with probs + + size_t smem_bytes = topK * sizeof(TCompute); + + if (topK <= 8) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 16) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 32) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 64) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 128) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else { + NVTE_ERROR("topK cannot exceed 128."); + } + } + } +} + +template +void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const float *prob, + const int num_rows, const int topK, const int num_cols, + cudaStream_t stream) { + using TCompute = typename std::conditional<(std::is_same::value || + std::is_same::value), + half, T>::type; + + static constexpr int kElementsPerAccess = 16 / sizeof(T); + + int blocks = num_rows; + int threads = std::min(num_cols / kElementsPerAccess, 1024); + size_t smem_bytes = topK * sizeof(TCompute); + + if (prob == nullptr) { + // moe_permute_bwd + // moe_unpermute_fwd without probs + + moe_unpermute_kernel<<>>( + input, output, row_id_map, nullptr, num_rows, topK, num_cols); + } else { + // moe_unpermute_fwd with probs + + moe_unpermute_kernel<<>>( + input, output, row_id_map, prob, num_rows, topK, num_cols); + } +} + +void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor sorted_row_id, + NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad, + const NVTETensor input_fwd, const int num_rows, const int topK, + const int num_cols, const int num_out_tokens, cudaStream_t stream) { + NVTE_API_CALL(nvte_permute); + + const transformer_engine::Tensor *input_cu = + reinterpret_cast(input); + const transformer_engine::Tensor *output_cu = + reinterpret_cast(output); + const transformer_engine::Tensor *sorted_row_id_cu = + reinterpret_cast(sorted_row_id); + const transformer_engine::Tensor *row_id_map_cu = + reinterpret_cast(row_id_map); + const transformer_engine::Tensor *prob_cu = + reinterpret_cast(prob); + const transformer_engine::Tensor *prob_grad_cu = + reinterpret_cast(prob_grad); + const transformer_engine::Tensor *input_fwd_cu = + reinterpret_cast(input_fwd); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + input_cu->data.dtype, T, + nvte_permute_launcher(reinterpret_cast(input_cu->data.dptr), + reinterpret_cast(output_cu->data.dptr), + 
reinterpret_cast(sorted_row_id_cu->data.dptr), + reinterpret_cast(row_id_map_cu->data.dptr), + reinterpret_cast(prob_cu->data.dptr), + reinterpret_cast(prob_grad_cu->data.dptr), + reinterpret_cast(input_fwd_cu->data.dptr), num_rows, topK, + num_cols, num_out_tokens, stream);); +} + +void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map, + const NVTETensor prob, const int num_rows, const int topK, const int num_cols, + cudaStream_t stream) { + NVTE_API_CALL(nvte_unpermute); + + const transformer_engine::Tensor *input_cu = + reinterpret_cast(input); + const transformer_engine::Tensor *output_cu = + reinterpret_cast(output); + const transformer_engine::Tensor *row_id_map_cu = + reinterpret_cast(row_id_map); + const transformer_engine::Tensor *prob_cu = + reinterpret_cast(prob); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + input_cu->data.dtype, T, + nvte_unpermute_launcher(reinterpret_cast(input_cu->data.dptr), + reinterpret_cast(output_cu->data.dptr), + reinterpret_cast(row_id_map_cu->data.dptr), + reinterpret_cast(prob_cu->data.dptr), num_rows, topK, + num_cols, stream);); +} diff --git a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp index e9a6ff483d..9b143b2f85 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp @@ -159,6 +159,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens params.epsilon = epsilon; params.amax = z->amax.dptr; params.scale = z->scale.dptr; + params.scale_inv = z->scale_inv.dptr; params.fp8_out = fp8_out; params.zero_centered_gamma = zero_centered_gamma; diff --git a/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh b/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh index a1cfc2293c..c435ae3744 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh +++ b/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh @@ -125,10 +125,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke } } if (params.fp8_out) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0 && threadIdx.y == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); + // Reduce amax over block + if (params.amax != nullptr) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { + reciprocal(reinterpret_cast(params.scale_inv), scale); } } } @@ -267,10 +275,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ // Finalize fp8 factors if (params.fp8_out) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); + // Reduce amax over block + if (params.amax != nullptr) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { + reciprocal(reinterpret_cast(params.scale_inv), scale); } } } diff --git a/transformer_engine/common/transpose/cast_transpose.cu b/transformer_engine/common/transpose/cast_transpose.cu index 6cbd4daade..dd45d0a668 100644 --- 
a/transformer_engine/common/transpose/cast_transpose.cu +++ b/transformer_engine/common/transpose/cast_transpose.cu @@ -101,14 +101,11 @@ struct KernelConfig { }; template -__global__ void __launch_bounds__(block_size) - cast_transpose_general_kernel(const IType *__restrict__ const input, - const CType *__restrict__ const noop, - OType *__restrict__ const output_c, - OType *__restrict__ const output_t, - const CType *__restrict__ const scale_ptr, - CType *__restrict__ const amax_ptr, const size_t row_length, - const size_t num_rows) { +__global__ void __launch_bounds__(block_size) cast_transpose_general_kernel( + const IType *__restrict__ const input, const CType *__restrict__ const noop, + OType *__restrict__ const output_c, OType *__restrict__ const output_t, + const CType *__restrict__ const scale_ptr, CType *__restrict__ const amax_ptr, + CType *__restrict__ const scale_inv_ptr, const size_t row_length, const size_t num_rows) { if (noop != nullptr && noop[0] == 1.0f) return; // Vectorized load/store sizes @@ -207,9 +204,15 @@ __global__ void __launch_bounds__(block_size) if (amax_ptr != nullptr) { amax = reduce_max(amax, tidy); if (threadIdx.x == 0) { + static_assert(std::is_same::value); atomicMaxFloat(amax_ptr, amax); } } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv_ptr != nullptr) { + reciprocal(scale_inv_ptr, scale); + } } } // namespace @@ -255,6 +258,8 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output "Cast and transposed outputs need to share amax tensor."); NVTE_CHECK(cast_output.scale.dptr == transposed_output.scale.dptr, "Cast and transposed outputs need to share scale tensor."); + NVTE_CHECK(cast_output.scale_inv.dptr == transposed_output.scale_inv.dptr, + "Cast and transposed outputs need to share scale-inverse tensor."); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.data.dtype, InputType, @@ -324,7 +329,9 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output static_cast(cast_output.data.dptr), static_cast(transposed_output.data.dptr), static_cast(cast_output.scale.dptr), - static_cast(cast_output.amax.dptr), row_length, num_rows); + static_cast(cast_output.amax.dptr), + static_cast(cast_output.scale_inv.dptr), row_length, + num_rows); } else { // Statically-compiled general kernel constexpr size_t load_size = 4; constexpr size_t store_size = 4; @@ -339,7 +346,8 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output static_cast(cast_output.data.dptr), static_cast(transposed_output.data.dptr), static_cast(cast_output.scale.dptr), - static_cast(cast_output.amax.dptr), row_length, num_rows); + static_cast(cast_output.amax.dptr), + static_cast(cast_output.scale_inv.dptr), row_length, num_rows); }); // NOLINT(*) ); // NOLINT(*) } diff --git a/transformer_engine/common/transpose/cast_transpose_fusion.cu b/transformer_engine/common/transpose/cast_transpose_fusion.cu index d96757990f..a8361d57ea 100644 --- a/transformer_engine/common/transpose/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/cast_transpose_fusion.cu @@ -433,15 +433,19 @@ __global__ void __launch_bounds__(cast_transpose_num_threads) } } - /* warp tile amax reduce*/ - amax = reduce_max(amax, warp_id); - - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - if (param.amax != nullptr) { + // Reduce amax over block + if (param.amax != nullptr) { + amax = reduce_max(amax, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); 
atomicMaxFloat(param.amax, amax); } } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && param.scale_inv != nullptr) { + reciprocal(param.scale_inv, scale); + } } static const char *ActTypeToString[] = { @@ -870,17 +874,18 @@ __global__ void __launch_bounds__(cast_transpose_num_threads) __syncthreads(); } - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); - - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - if (amax != nullptr) { + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); atomicMaxFloat(amax, max); } - if (scale_inv != nullptr) { - reciprocal(scale_inv, scale); - } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, scale); } } @@ -1079,17 +1084,18 @@ __global__ void __launch_bounds__(cast_transpose_num_threads) __syncthreads(); } - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); - - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - if (amax != nullptr) { + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); atomicMaxFloat(amax, max); } - if (scale_inv != nullptr) { - reciprocal(scale_inv, scale); - } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, scale); } } diff --git a/transformer_engine/common/transpose/multi_cast_transpose.cu b/transformer_engine/common/transpose/multi_cast_transpose.cu index 8e6e90a7bf..4026016519 100644 --- a/transformer_engine/common/transpose/multi_cast_transpose.cu +++ b/transformer_engine/common/transpose/multi_cast_transpose.cu @@ -36,6 +36,8 @@ struct MultiCastTransposeArgs { void* scale_list[kMaxTensorsPerKernel]; // (output) AMAX's of input tensors void* amax_list[kMaxTensorsPerKernel]; + // (output) Inverse of scaling factor for output tensors + void* scale_inv_list[kMaxTensorsPerKernel]; // Input matrix heights int num_rows_list[kMaxTensorsPerKernel]; // Input matrix widths @@ -82,7 +84,8 @@ __global__ void __launch_bounds__(threads_per_block) OType* output_t = reinterpret_cast(args.output_t_list[tensor_id]); const CType* scale_ptr = reinterpret_cast(args.scale_list[tensor_id]); const CType scale = scale_ptr == nullptr ? 
1 : *scale_ptr; - CType* amax = reinterpret_cast(args.amax_list[tensor_id]); + CType* amax_ptr = reinterpret_cast(args.amax_list[tensor_id]); + CType* scale_inv_ptr = reinterpret_cast(args.scale_inv_list[tensor_id]); const int num_rows = args.num_rows_list[tensor_id]; const int row_length = args.row_length_list[tensor_id]; @@ -183,7 +186,10 @@ __global__ void __launch_bounds__(threads_per_block) local_amax = reduce_max(local_amax, tidy); if (tid == 0) { static_assert(std::is_same::value); - if (amax != nullptr) atomicMaxFloat(amax, local_amax); + if (amax_ptr != nullptr) atomicMaxFloat(amax_ptr, local_amax); + } + if (tile_id == 0 && tid == 0 && scale_inv_ptr != nullptr) { + reciprocal(scale_inv_ptr, scale); } } @@ -285,6 +291,7 @@ void multi_cast_transpose(const std::vector input_list, kernel_args.output_t_list[pos] = transposed_output_list[tensor_id]->data.dptr; kernel_args.scale_list[pos] = cast_output_list[tensor_id]->scale.dptr; kernel_args.amax_list[pos] = cast_output_list[tensor_id]->amax.dptr; + kernel_args.scale_inv_list[pos] = cast_output_list[tensor_id]->scale_inv.dptr; kernel_args.num_rows_list[pos] = num_rows; kernel_args.row_length_list[pos] = row_length; kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles; diff --git a/transformer_engine/common/transpose/rtc/cast_transpose.cu b/transformer_engine/common/transpose/rtc/cast_transpose.cu index 6ea8326147..07244a42e9 100644 --- a/transformer_engine/common/transpose/rtc/cast_transpose.cu +++ b/transformer_engine/common/transpose/rtc/cast_transpose.cu @@ -25,7 +25,7 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel( const IType* __restrict__ const input, const CType* __restrict__ const noop, OType* __restrict__ const output_c, OType* __restrict__ const output_t, const CType* __restrict__ const scale_ptr, CType* __restrict__ const amax_ptr, - const size_t row_length, const size_t num_rows) { + CType* __restrict__ const scale_inv_ptr, const size_t row_length, const size_t num_rows) { if (noop != nullptr && noop[0] == 1.0f) return; // Vectorized load/store sizes @@ -121,4 +121,9 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel( atomicMaxFloat(amax_ptr, amax); } } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv_ptr != nullptr) { + reciprocal(scale_inv_ptr, scale); + } } diff --git a/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu b/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu index c005be98ef..4ba1cb4c69 100644 --- a/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu @@ -229,12 +229,16 @@ __global__ void __launch_bounds__(BLOCK_SIZE) } } - // warp tile amax reduce - const CType max_block = reduce_max(amax, warp_id); - - if (threadIdx.x == 0) { - if (param.amax != nullptr) { + // Reduce amax over block + if (param.amax != nullptr) { + const CType max_block = reduce_max(amax, warp_id); + if (threadIdx.x == 0) { atomicMaxFloat(param.amax, max_block); } } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && param.scale_inv != nullptr) { + reciprocal(param.scale_inv, scale); + } } diff --git a/transformer_engine/common/util/cast.cu b/transformer_engine/common/util/cast.cu index 8dd2b98ebf..dd03afd21b 100644 --- a/transformer_engine/common/util/cast.cu +++ b/transformer_engine/common/util/cast.cu @@ -46,7 +46,8 @@ void fp8_quantize(const Tensor &input, Tensor *output, cudaStream_t 
stream) { reinterpret_cast(input.data.dptr), reinterpret_cast(output->data.dptr), reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), N, {}, + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream);); // NOLINT(*) ); // NOLINT(*) } @@ -68,7 +69,7 @@ void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { p.scale_inv = reinterpret_cast(input.scale_inv.dptr); VectorizedUnaryKernelLauncher( reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), nullptr, nullptr, N, p, + reinterpret_cast(output->data.dptr), nullptr, nullptr, nullptr, N, p, stream);); // NOLINT(*) ); // NOLINT(*) } diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu new file mode 100644 index 0000000000..017d2e6a56 --- /dev/null +++ b/transformer_engine/common/util/padding.cu @@ -0,0 +1,219 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include +#include +#include + +#include "../common.h" +#include "../utils.cuh" + +namespace transformer_engine { + +namespace { + +// Parameters to tune +constexpr int n_warps_per_tile = 4; +constexpr int threads_per_block = THREADS_PER_WARP * n_warps_per_tile; +constexpr int desired_load_store_size = 8; +constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB + +struct MultiPaddingArgs { + // (input) Data buffers for input tensors + void* input_list[kMaxTensorsPerKernel]; + // (output) Data buffers for cast output tensors + void* output_list[kMaxTensorsPerKernel]; + // Input matrix heights + int num_rows_list[kMaxTensorsPerKernel]; + // Input matrix heights (padded) + int padded_num_rows_list[kMaxTensorsPerKernel]; + // Input matrix widths + int row_length_list[kMaxTensorsPerKernel]; + // tensor + int block_range[kMaxTensorsPerKernel + 1]; + // Number of tensors being processed by kernel + int num_tensors; +}; + +template +__global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiPaddingArgs args) { + using Vec = Vec; + + // Thread indices + // Note: Block is interpreted as a warp_size x num_warps grid + constexpr int bdimx = THREADS_PER_WARP; + constexpr int bdimy = n_warps_per_tile; + const int tid = threadIdx.x; + const int tidx = tid % bdimx; + const int tidy = tid / bdimx; + const int bid = blockIdx.x; + + // Input tensors are divided into tiles + // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles + constexpr int tile_dim_m = THREADS_PER_WARP * nvec; + constexpr int tile_dim_n = THREADS_PER_WARP * nvec; + + // Number of nvec x nvec subtiles for each thread to + // load/store + constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile; + + // Find tensor corresponding to block + int tensor_id = 0; + while (args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } + const Type* input = reinterpret_cast(args.input_list[tensor_id]); + Type* output = reinterpret_cast(args.output_list[tensor_id]); + const int num_rows = args.num_rows_list[tensor_id]; + const int padded_num_rows = args.padded_num_rows_list[tensor_id]; + const int row_length = args.row_length_list[tensor_id]; + + // Find position of tile within tensor + const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; + const int tile_id = bid - 
args.block_range[tensor_id]; + const int tile_id_m = tile_id / num_tiles_n; + const int tile_id_n = tile_id % num_tiles_n; + const int tile_row = tile_id_m * tile_dim_m; + const int tile_col = tile_id_n * tile_dim_n; + + // Load input and store to registers + // Note: Each thread loads n_iterations subtiles, casts to output + // type, and transposes in registers. + Type local_zero = static_cast(0.f); +#pragma unroll + for (int iter = 0; iter < n_iterations; ++iter) { + const int i1 = tidy + iter * bdimy; + const int j1 = tidx; +#pragma unroll + for (int i2 = 0; i2 < nvec; ++i2) { + const int row = tile_row + i1 * nvec + i2; + const int col = tile_col + j1 * nvec; + Vec local_input; + Vec local_output; + local_input.clear(); + if (row < num_rows) { + for (int j2 = 0; j2 < nvec; ++j2) { + if (col + j2 < row_length) { + local_input.data.elt[j2] = input[row * row_length + col + j2]; + } + } + } +#pragma unroll + for (int j2 = 0; j2 < nvec; ++j2) { + local_output.data.elt[j2] = local_input.data.elt[j2]; + } + if (row < num_rows) { + for (int j2 = 0; j2 < nvec; ++j2) { + if (col + j2 < row_length) { + output[row * row_length + col + j2] = local_output.data.elt[j2]; + } + } + } else if (row < padded_num_rows) { + // padding + for (int j2 = 0; j2 < nvec; ++j2) { + if (col + j2 < row_length) { + output[row * row_length + col + j2] = local_zero; + } + } + } + } + } +} + +} // namespace + +void multi_padding(const std::vector input_list, std::vector output_list, + const std::vector padded_num_rows_list, cudaStream_t stream) { + // Check that number of tensors is valid + NVTE_CHECK(output_list.size() == input_list.size(), + "Number of input and output tensors must match"); + if (input_list.empty()) { + return; + } + + // Check that tensor properties are valid + DType type = input_list[0]->data.dtype; + for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { + const auto& input = *input_list[tensor_id]; + const auto& output = *output_list[tensor_id]; + CheckInputTensor(input, "multi_padding_input_" + std::to_string(tensor_id)); + CheckInputTensor(output, "multi_padding_output_" + std::to_string(tensor_id)); + + NVTE_CHECK(input.data.dtype == type, "Input tensor types do not match."); + NVTE_CHECK(output.data.dtype == type, "Output tensor types do not match."); + + NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions."); + NVTE_CHECK(output.data.shape[0] == padded_num_rows_list[tensor_id], + "output tensor shape does not match padded input shape."); + } + + // Input matrices are divided into tiles + // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles + const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size / typeToSize(type); + const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size / typeToSize(type); + + // Add tensors to kernel argument struct + MultiPaddingArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.block_range[0] = 0; + for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { + // Launch kernel if argument struct is full + if (kernel_args.num_tensors == kMaxTensorsPerKernel) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type); + const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; + multi_padding_kernel + <<>>(kernel_args);); // NOLINT(*) + kernel_args.num_tensors = 0; + } + + // Calculate number of thread blocks needed for tensor + const int num_rows = input_list[tensor_id]->data.shape[0]; + const 
int padded_num_rows = padded_num_rows_list[tensor_id]; + const int row_length = input_list[tensor_id]->data.shape[1]; + const int num_tiles_m = (padded_num_rows + tile_dim_m - 1) / tile_dim_m; + const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; + const int num_tiles = num_tiles_m * num_tiles_n; + + // Add tensor to kernel argument struct + const int pos = kernel_args.num_tensors; + kernel_args.input_list[pos] = const_cast(input_list[tensor_id]->data.dptr); + kernel_args.output_list[pos] = output_list[tensor_id]->data.dptr; + kernel_args.num_rows_list[pos] = num_rows; + kernel_args.padded_num_rows_list[pos] = padded_num_rows; + kernel_args.row_length_list[pos] = row_length; + kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles; + kernel_args.num_tensors++; + } + + // Launch kernel + if (kernel_args.num_tensors > 0) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type); + const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; + multi_padding_kernel + <<>>(kernel_args);); // NOLINT(*) + } +} + +} // namespace transformer_engine + +void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list, + const int* padded_num_rows_list, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_padding); + using namespace transformer_engine; + std::vector input_list_, output_list_; + std::vector padded_num_rows_list_; + for (size_t i = 0; i < num_tensors; ++i) { + input_list_.push_back(reinterpret_cast(const_cast(input_list[i]))); + output_list_.push_back(reinterpret_cast(output_list[i])); + padded_num_rows_list_.push_back(padded_num_rows_list[i]); + } + multi_padding(input_list_, output_list_, padded_num_rows_list_, stream); +} diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 63ad1857cf..8653bf45a4 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -168,12 +168,12 @@ template __launch_bounds__(unary_kernel_threads) __global__ void unary_kernel(const InputType *input, OutputType *output, const ComputeType *scale, - ComputeType *amax, Param p, const size_t N, + ComputeType *amax, ComputeType *scale_inv, Param p, const size_t N, const size_t num_aligned_elements) { VectorizedLoader loader(input, N); VectorizedStorer storer(output, N); ComputeType max = 0; - ComputeType s = 0; + ComputeType s = 1; if constexpr (is_fp8::value) { if (scale != nullptr) s = *scale; } @@ -199,12 +199,18 @@ __launch_bounds__(unary_kernel_threads) __global__ storer.store(tid, N); } if constexpr (is_fp8::value) { - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } + } - if (threadIdx.x == 0 && amax != nullptr) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); } } } @@ -214,13 +220,13 @@ template __launch_bounds__(unary_kernel_threads) __global__ void unary_grad_kernel(const InputTypeGrad *grad, const InputType *input, OutputType *output, - const ComputeType *scale, ComputeType *amax, Param p, const size_t N, - const size_t num_aligned_elements) { + const ComputeType *scale, ComputeType *amax, 
ComputeType *scale_inv, + Param p, const size_t N, const size_t num_aligned_elements) { VectorizedLoader loader(input, N); VectorizedLoader grad_loader(grad, N); VectorizedStorer storer(output, N); ComputeType max = 0; - ComputeType s = 0; + ComputeType s = 1; if constexpr (is_fp8::value) { if (scale != nullptr) s = *scale; } @@ -248,12 +254,18 @@ __launch_bounds__(unary_kernel_threads) __global__ storer.store(tid, N); } if constexpr (is_fp8::value) { - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } + } - if (threadIdx.x == 0 && amax != nullptr) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); } } } @@ -311,7 +323,7 @@ Alignment CheckAlignment(const size_t lead_dim, const int nvec, const T... ptrs) template void VectorizedUnaryKernelLauncher(const InputType *input, OutputType *output, const fp32 *scale, - fp32 *amax, const size_t N, const Param params, + fp32 *amax, fp32 *scale_inv, const size_t N, const Param params, cudaStream_t stream) { if (N != 0) { auto align = CheckAlignment(N, nvec, input, output); @@ -325,16 +337,16 @@ void VectorizedUnaryKernelLauncher(const InputType *input, OutputType *output, c switch (align) { case Alignment::SAME_ALIGNED: unary_kernel<<>>( - input, output, scale, amax, params, N, num_aligned_elements); + input, output, scale, amax, scale_inv, params, N, num_aligned_elements); break; case Alignment::SAME_UNALIGNED: unary_kernel<<>>( - input, output, scale, amax, params, N, num_aligned_elements); + input, output, scale, amax, scale_inv, params, N, num_aligned_elements); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize - unary_kernel<1, true, fp32, Param, OP> - <<>>(input, output, scale, amax, params, N, N); + unary_kernel<1, true, fp32, Param, OP><<>>( + input, output, scale, amax, scale_inv, params, N, N); break; } } @@ -345,7 +357,8 @@ template void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputType *input, OutputType *output, const fp32 *scale, fp32 *amax, - const size_t N, const Param params, cudaStream_t stream) { + fp32 *scale_inv, const size_t N, const Param params, + cudaStream_t stream) { if (N != 0) { auto align = CheckAlignment(N, nvec, input, grad, output); @@ -358,16 +371,16 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputTyp switch (align) { case Alignment::SAME_ALIGNED: unary_grad_kernel<<>>( - grad, input, output, scale, amax, params, N, num_aligned_elements); + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements); break; case Alignment::SAME_UNALIGNED: unary_grad_kernel<<>>( - grad, input, output, scale, amax, params, N, num_aligned_elements); + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize - unary_grad_kernel<1, true, fp32, Param, OP> - <<>>(grad, input, output, scale, amax, params, N, N); + unary_grad_kernel<1, true, fp32, Param, OP><<>>( + grad, input, output, scale, amax, scale_inv, params, N, N); break; } } @@ -379,8 +392,8 @@ template __launch_bounds__(unary_kernel_threads) __global__ void gated_act_kernel(const InputType *input, OutputType 
*output, const ComputeType *scale, - ComputeType *amax, const size_t m, const size_t n, const Param p, - const size_t num_aligned_elements) { + ComputeType *amax, ComputeType *scale_inv, const size_t m, const size_t n, + const Param p, const size_t num_aligned_elements) { const size_t M = num_aligned_elements * m; for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { const size_t id_x = tid % num_aligned_elements; @@ -389,7 +402,7 @@ __launch_bounds__(unary_kernel_threads) __global__ VectorizedLoader loader1(input + id_y * n * 2 + n, n); VectorizedStorer storer(output + id_y * n, n); ComputeType max = 0; - ComputeType s = 0; + ComputeType s = 1; if constexpr (is_fp8::value) { if (scale != nullptr) s = *scale; } @@ -412,12 +425,18 @@ __launch_bounds__(unary_kernel_threads) __global__ storer.store(id_x, n); if constexpr (is_fp8::value) { - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } + } - if (threadIdx.x == 0 && amax != nullptr) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); } } } @@ -427,8 +446,8 @@ template void GatedActivationKernelLauncher(const InputType *input, OutputType *output, const fp32 *scale, - fp32 *amax, const size_t m, const size_t n, const Param &p, - cudaStream_t stream) { + fp32 *amax, fp32 *scale_inv, const size_t m, const size_t n, + const Param &p, cudaStream_t stream) { if (m != 0 && n != 0) { size_t num_aligned_elements = get_num_aligned_elements(input, n, nvec, sizeof(InputType)); constexpr size_t threads = unary_kernel_threads; @@ -439,18 +458,18 @@ void GatedActivationKernelLauncher(const InputType *input, OutputType *output, c switch (auto align = CheckAlignment(n, nvec, input, input + n, output)) { case Alignment::SAME_ALIGNED: gated_act_kernel - <<>>(input, output, scale, amax, m, n, p, + <<>>(input, output, scale, amax, scale_inv, m, n, p, num_aligned_elements); break; case Alignment::SAME_UNALIGNED: gated_act_kernel - <<>>(input, output, scale, amax, m, n, p, + <<>>(input, output, scale, amax, scale_inv, m, n, p, num_aligned_elements); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize gated_act_kernel<1, true, ComputeType, Param, Activation> - <<>>(input, output, scale, amax, m, n, p, n); + <<>>(input, output, scale, amax, scale_inv, m, n, p, n); break; } } diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index bcfc0c608d..6703ce728c 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -852,6 +852,11 @@ __device__ __forceinline__ void reciprocal(T *value_inv, const T value) { *value_inv = 1 / value; } +template <> +__device__ __forceinline__ void reciprocal(float *value_inv, const float value) { + *value_inv = __frcp_rn(value); +} + } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_ diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index dcd860c3a4..9b8279be25 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -43,6 +43,8 @@ class AttnMaskType(Enum): PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK 
PADDING_CAUSAL_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK + CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK + PADDING_CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK class QKVLayout(Enum): @@ -97,11 +99,21 @@ def canonicalize_attn_mask_type(attn_mask_type: str): return AttnMaskType.PADDING_MASK case "causal": return AttnMaskType.CAUSAL_MASK + case "causal_bottom_right" | "bottom_right_causal": + return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK case "padding_causal" | "causal_padding": return AttnMaskType.PADDING_CAUSAL_MASK + case ( + "padding_causal_bottom_right" + | "causal_padding_bottom_right" + | "bottom_right_causal_padding" + | "bottom_right_padding_causal" + ): + return AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK raise ValueError( - f"Unsupported {attn_mask_type=}, supported attn_mask_type=" - "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}" + f"Unsupported {attn_mask_type=}, supported attn_mask_type={{'no_mask', 'padding', 'causal'," + " 'padding_causal', 'causal_padding', 'causal_bottom_right'," + " 'padding_causal_bottom_right'}" ) @@ -155,6 +167,75 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout): return batch, q_max_seqlen, kv_max_seqlen +def _reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat, inverse: bool): + match tensor_format: + case QKVFormat.SBHD: + seq_dim = 0 + case QKVFormat.BSHD: + seq_dim = 1 + case _: + raise ValueError(f"{tensor_format=} is not supported for causal load balancing.") + + if cp_size == 1: + return tensor + + if cp_size % 2 != 0: + raise ValueError(f"{cp_size=} must be a multiple of 2.") + + # Need to ensure we have 2 pairs to swap for balancing between cp ranks + if tensor.shape[seq_dim] % (cp_size * 2) != 0: + raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}") + + # [B, S, H, D] -> [B, 2*cp_size, S/2*cp_size, D] + # [S, B, H, D] -> [2*cp_size, S/2*cp_size, B, H, D] + ori_tensor_shape = tensor.shape + tensor = tensor.reshape( + ( + *ori_tensor_shape[:seq_dim], + 2 * cp_size, + ori_tensor_shape[seq_dim] // (2 * cp_size), + *ori_tensor_shape[seq_dim + 1 :], + ) + ) + + parts = [] + if not inverse: + for cp_rank in range(cp_size): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + index = jnp.array([cp_rank, (2 * cp_size - cp_rank - 1)]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) + else: + for cp_rank in range(cp_size // 2): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + base = 4 * cp_rank + index = jnp.array([base, base + 2]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) + for cp_rank in range(cp_size // 2): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + base = 2 * cp_size - 1 - 4 * cp_rank + index = jnp.array([base, base - 2]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) + + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] + combined = jnp.stack(parts, axis=seq_dim) + + return combined.reshape(ori_tensor_shape) + + +def reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat): + """Reorders a tensor for load balancing the compute of causal attention.""" + return 
_reorder_causal_load_balancing(tensor, cp_size, tensor_format, False) + + +def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat): + """Inverse operation of `reorder_causal_load_balancing`.""" + return _reorder_causal_load_balancing(tensor, cp_size, tensor_format, True) + + def fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], @@ -166,6 +247,8 @@ def fused_attn( scaling_factor: float, dropout_probability: float, is_training: bool, + context_parallel_causal_load_balanced: bool = False, + context_parallel_axis: str = "", ): """ Perform non-THD (non-packed) cuDNN fused attention. @@ -192,6 +275,9 @@ def fused_attn( scaling_factor (float): Scaling factor for the attention scores. dropout_probability (float): Dropout probability to apply during attention. is_training (bool): Flag indicating whether the model is in training mode. + context_parallel_causal_load_balanced (bool): + Indicates the sequences are ordered for causal mask load balancing when running context parallelism. + context_parallel_axis (str): The name of the context parallel axis. Returns: (jnp.ndarray): The output tensor from the fused attention. """ @@ -213,7 +299,11 @@ def fused_attn( ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" # convert the mask to seqlens, mask doesn't support ragged offsets - if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: + if attn_mask_type in [ + AttnMaskType.NO_MASK, + AttnMaskType.CAUSAL_MASK, + AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, + ]: batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout) q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32) kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32) @@ -242,6 +332,8 @@ def fused_attn( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=1, + context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, + context_parallel_axis=context_parallel_axis, ) return output @@ -262,6 +354,8 @@ def fused_attn_thd( dropout_probability: float, is_training: bool, max_segments_per_seq: int = 1, + context_parallel_causal_load_balanced: bool = False, + context_parallel_axis: str = "", ): """ (Experimental) Perform THD (packed) cuDNN fused attention. @@ -300,6 +394,9 @@ def fused_attn_thd( Indicating the maximum number of segments inside a sequence. This parameter is to constrain the limit usage and need to be static during the e2e training. The XLA compile time and memory consumption is proportional to `max_segments_per_seq`. + context_parallel_causal_load_balanced (bool): + Indicates the sequences are ordered for causal mask load balancing when running context parallelism. + context_parallel_axis (str): The name of the context parallel axis. Returns: (jnp.ndarray): The output tensor from the fused attention. 
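The reordering above is easiest to see as a permutation of sequence chunks: the sequence axis is split into 2 * cp_size chunks, and context-parallel rank r is assigned chunks r and 2*cp_size - 1 - r, so every rank gets a comparable mix of cheap (early) and expensive (late) causal-attention work. The sketch below is illustrative only and not part of the patch; chunk_permutation is a hypothetical helper, and the final assertion shows that, at chunk granularity, the inverse reorder amounts to the argsort of the forward permutation.

import jax.numpy as jnp

def chunk_permutation(cp_size: int) -> jnp.ndarray:
    """Chunk order produced by the load-balanced reordering (illustrative helper)."""
    order = []
    for cp_rank in range(cp_size):
        order += [cp_rank, 2 * cp_size - 1 - cp_rank]
    return jnp.array(order)

cp_size = 4
chunks = jnp.arange(2 * cp_size)        # one id per sequence chunk
perm = chunk_permutation(cp_size)       # [0, 7, 1, 6, 2, 5, 3, 4]
balanced = chunks[perm]                 # chunk order after reorder_causal_load_balancing
restored = balanced[jnp.argsort(perm)]  # what inverse_reorder_causal_load_balancing undoes
assert (restored == chunks).all()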
@@ -354,12 +451,14 @@ def fused_attn_thd( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, + context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, + context_parallel_axis=context_parallel_axis, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15)) def _fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], @@ -375,6 +474,8 @@ def _fused_attn( dropout_probability: float, is_training: bool, max_segments_per_seq: int, + context_parallel_causal_load_balanced: bool, + context_parallel_axis: str, ): output, _ = _fused_attn_fwd_rule( qkv, @@ -391,6 +492,8 @@ def _fused_attn( dropout_probability, is_training, max_segments_per_seq, + context_parallel_causal_load_balanced, + context_parallel_axis, ) return output @@ -410,6 +513,8 @@ def _fused_attn_fwd_rule( dropout_probability, is_training, max_segments_per_seq, + context_parallel_causal_load_balanced, + context_parallel_axis, ): output, softmax_aux, rng_state = tex.fused_attn_fwd( qkv, @@ -426,6 +531,8 @@ def _fused_attn_fwd_rule( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, + context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, + context_parallel_axis=context_parallel_axis, ) output = checkpoint_name(output, "context") softmax_aux = checkpoint_name(softmax_aux, "context") @@ -451,6 +558,8 @@ def _fused_attn_bwd_rule( dropout_probability, is_training, max_segments_per_seq, + context_parallel_causal_load_balanced, + context_parallel_axis, ctx, dz, ): @@ -483,6 +592,8 @@ def _fused_attn_bwd_rule( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, + context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, + context_parallel_axis=context_parallel_axis, ) if attn_bias_type == AttnBiasType.NO_BIAS: grad_bias = None diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 56359646b1..47483c67ea 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -467,7 +467,7 @@ def sharded_impl(x, amax, scale, scale_inv): local_x, local_amax = ActLuFp8Primitive.impl( x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum ) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) return local_x, global_updated_amax diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 76ccec363b..d5b901c107 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -9,8 +9,9 @@ from typing import Optional, Tuple import warnings +import jax import jax.numpy as jnp -from jax import dtypes +from jax import dtypes, lax from jax.interpreters import mlir from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding @@ -34,7 +35,11 @@ get_cudnn_version, ) from ..sharding import ( + global_mesh_resource, + lax_paral_op, all_reduce_sum_along_dp_fsdp, + get_mesh_axis_size, + get_mesh_axis_rank, get_all_mesh_axes, num_of_devices, ) @@ -47,6 +52,38 @@ ] +@partial( + jax.tree_util.register_dataclass, + data_fields=[], + 
meta_fields=[ + "attn_bias_type", + "attn_mask_type", + "qkv_layout", + "scaling_factor", + "dropout_probability", + "is_training", + "max_segments_per_seq", + "context_parallel_load_balanced", + "cp_axis", + ], +) +@dataclass(frozen=True) +class _FusedAttnConfig: + """ + Passes static configuration properties of fused attention. + """ + + attn_bias_type: NVTE_Bias_Type + attn_mask_type: NVTE_Mask_Type + qkv_layout: NVTE_QKV_Layout + scaling_factor: float + dropout_probability: float + is_training: bool + max_segments_per_seq: int + context_parallel_load_balanced: bool + cp_axis: str + + @dataclass(frozen=True) class FusedAttnHelper: """ @@ -178,7 +215,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): name = "te_fused_attn_forward" multiple_results = True - impl_static_args = (9, 10, 11, 12, 13, 14, 15) + impl_static_args = (9,) inner_primitive = None outer_primitive = None @@ -194,13 +231,7 @@ def abstract( _k_seq_offsets, seed_aval, *, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, + config: _FusedAttnConfig, ): """ Fused attention fwd abstract @@ -213,7 +244,7 @@ def abstract( assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( - FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) + FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) ) output_shape = (*batch_shape, q_max_seqlen, attn_heads, head_dim) @@ -223,10 +254,10 @@ def abstract( backend = FusedAttnHelper( q_dtype, k_dtype, - qkv_layout, - attn_bias_type, - attn_mask_type, - dropout_probability, + config.qkv_layout, + config.attn_bias_type, + config.attn_mask_type, + config.dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, @@ -238,7 +269,7 @@ def abstract( softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen) softmax_dtype = q_dtype elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: - softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, max_segments_per_seq) + softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, config.max_segments_per_seq) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: raise ValueError(f"Unsupported {backend=}") @@ -252,7 +283,7 @@ def abstract( rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype) - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: bias_batch = bias_heads = 0 else: *bias_batch_shape, bias_heads, _, _ = bias_aval.shape @@ -270,14 +301,14 @@ def abstract( num_gqa_groups, bias_heads, head_dim, - scaling_factor, - dropout_probability, - attn_bias_type, - attn_mask_type, - qkv_layout, + config.scaling_factor, + config.dropout_probability, + config.attn_bias_type, + config.attn_mask_type, + config.qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), - is_training, - max_segments_per_seq, + config.is_training, + config.max_segments_per_seq, ) wkspace_aval = q_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) @@ -308,28 +339,12 @@ def lowering( k_seq_offsets, seed, *, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, + config: _FusedAttnConfig, ): """ Fused attention fwd lowering rules """ - operands = [ - q, - k, - v, - bias, - q_cu_seqlen, - 
kv_cu_seqlen, - q_seq_offsets, - k_seq_offsets, - seed, - ] + operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, q_seq_offsets, k_seq_offsets, seed] operand_shapes = map(lambda x: x.type.shape, operands) out_types = [ ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) @@ -340,12 +355,12 @@ def lowering( q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( - FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) + FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) ) input_batch = reduce(operator.mul, batch_shape) - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: bias_batch = bias_heads = 0 else: *bias_batch_shape, bias_heads, _, _ = bias_aval.shape @@ -362,16 +377,16 @@ def lowering( num_gqa_groups, bias_heads, head_dim, - max_segments_per_seq, + config.max_segments_per_seq, wkspace_aval.size, - scaling_factor, - dropout_probability, - attn_bias_type, - attn_mask_type, - qkv_layout, + config.scaling_factor, + config.dropout_probability, + config.attn_bias_type, + config.attn_mask_type, + config.qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - is_training, + config.is_training, not FusedAttnHelper.is_non_deterministic_allowed(), ) @@ -390,17 +405,11 @@ def impl( q_seq_offsets, k_seq_offsets, seed, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, + config: _FusedAttnConfig, ): assert FusedAttnFwdPrimitive.inner_primitive is not None - if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD: + if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD: def _fix_len_take(x, condition, fill_value=-1): x_shape = x.shape @@ -418,7 +427,7 @@ def convert_to_2d(offsets, batch, max_seqlen): ) return offsets_2d - match qkv_layout: + match config.qkv_layout: case NVTE_QKV_Layout.NVTE_T3HD: kv_max_seqlen = q_max_seqlen = q.shape[-4] kv_batch = q_batch = reduce(operator.mul, q.shape[:-4]) @@ -472,66 +481,27 @@ def convert_to_2d(offsets, batch, max_seqlen): q_seq_offsets, k_seq_offsets, seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training, - max_segments_per_seq=max_segments_per_seq, + config=config, ) return output, softmax_aux, rng_state @staticmethod - def batcher( - batched_args, - batch_dims, - *, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, - ): + def batcher(batched_args, batch_dims, *, config): check_valid_batch_dims(batch_dims) assert FusedAttnFwdPrimitive.outer_primitive is not None q_bdim, *_, seed_bdim = batch_dims out_bdims = q_bdim, q_bdim, seed_bdim return ( - FusedAttnFwdPrimitive.outer_primitive.bind( - *batched_args, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training, - max_segments_per_seq=max_segments_per_seq, - ), + FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config), out_bdims, ) @staticmethod - def infer_sharding_from_operands( - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - 
max_segments_per_seq, - mesh, - arg_infos, - result_infos, - ): - del attn_bias_type, attn_mask_type, scaling_factor - del dropout_probability, is_training, max_segments_per_seq, result_infos + def infer_sharding_from_operands(config, mesh, arg_infos, result_infos): + del result_infos q_spec = get_padded_spec(arg_infos[0]) - k_spec = get_padded_spec(arg_infos[1]) - match qkv_layout: + match config.qkv_layout: case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD: # q_spec = (...batch, q_seqlen, head, hidden) out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:])) @@ -543,33 +513,22 @@ def infer_sharding_from_operands( # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden) out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-4]) + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) ) case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD: # q_spec = (...batch, q_seqlen, head, hidden) # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden) out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3]) + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) ) case _: - raise ValueError(f"Unsupported {qkv_layout=}") + raise ValueError(f"Unsupported {config.qkv_layout=}") rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) return (out_sharding, softmax_aux_sharding, rng_state_sharding) @staticmethod - def partition( - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, - mesh, - arg_infos, - result_infos, - ): + def partition(config, mesh, arg_infos, result_infos): out_sharding = result_infos[0].sharding softmax_aux_sharding = result_infos[1].sharding rng_state_sharding = seed_sharding = NamedSharding( @@ -577,16 +536,7 @@ def partition( ) arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) - impl = partial( - FusedAttnFwdPrimitive.impl, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training, - max_segments_per_seq=max_segments_per_seq, - ) + impl = partial(FusedAttnFwdPrimitive.impl, config=config) return mesh, impl, out_shardings, arg_shardings @@ -600,7 +550,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): name = "te_fused_attn_backward" multiple_results = True - impl_static_args = (12, 13, 14, 15, 16, 17, 18) + impl_static_args = (12,) inner_primitive = None outer_primitive = None @@ -619,13 +569,7 @@ def abstract( _q_seq_offsets, _k_seq_offsets, *, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, + config, ): """ Fused attention bwd abstract @@ -641,10 +585,10 @@ def abstract( assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( - FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) + FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) ) - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type == 
NVTE_Bias_Type.NVTE_NO_BIAS: bias_batch = bias_heads = 0 else: *bias_batch_shape, bias_heads, _, _ = bias_aval.shape @@ -662,15 +606,15 @@ def abstract( num_gqa_groups, bias_heads, head_dim, - scaling_factor, - dropout_probability, - attn_bias_type, - attn_mask_type, - qkv_layout, + config.scaling_factor, + config.dropout_probability, + config.attn_bias_type, + config.attn_mask_type, + config.qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), - is_training, + config.is_training, deterministic, - max_segments_per_seq, + config.max_segments_per_seq, ) dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) @@ -707,13 +651,7 @@ def lowering( q_seq_offsets, k_seq_offsets, *, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, + config, ): """ Fused attention bwd lowering rules @@ -743,12 +681,12 @@ def lowering( q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( - FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) + FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) ) input_batch = reduce(operator.mul, batch_shape) - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: bias_batch = bias_heads = 0 else: *bias_batch_shape, bias_heads, _, _ = bias_aval.shape @@ -765,16 +703,16 @@ def lowering( num_gqa_groups, bias_heads, head_dim, - max_segments_per_seq, + config.max_segments_per_seq, wkspace_aval.size, - scaling_factor, - dropout_probability, - attn_bias_type, - attn_mask_type, - qkv_layout, + config.scaling_factor, + config.dropout_probability, + config.attn_bias_type, + config.attn_mask_type, + config.qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - is_training, + config.is_training, not FusedAttnHelper.is_non_deterministic_allowed(), ) @@ -796,17 +734,11 @@ def impl( kv_seqlen, q_seq_offsets, k_seq_offsets, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, + config, ): assert FusedAttnBwdPrimitive.inner_primitive is not None - if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD: + if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD: def _fix_len_take(x, condition, fill_value=-1): x_shape = x.shape @@ -825,7 +757,7 @@ def convert_to_2d(offsets, batch, max_seqlen): ) return offsets_2d - match qkv_layout: + match config.qkv_layout: case NVTE_QKV_Layout.NVTE_T3HD: kv_max_seqlen = q_max_seqlen = q.shape[-4] kv_batch = q_batch = reduce(operator.mul, q.shape[:-4]) @@ -882,63 +814,25 @@ def convert_to_2d(offsets, batch, max_seqlen): kv_cu_seqlen, q_seq_offsets, k_seq_offsets, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training, - max_segments_per_seq=max_segments_per_seq, + config=config, ) return dq, dk, dv, dbias @staticmethod - def batcher( - batched_args, - batch_dims, - *, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, - ): + def batcher(batched_args, batch_dims, *, config): check_valid_batch_dims(batch_dims) assert FusedAttnBwdPrimitive.outer_primitive is not None q_bdim, k_bdim, v_bdim, *_ = batch_dims out_bdims = q_bdim, k_bdim, v_bdim, q_bdim return ( - 
FusedAttnBwdPrimitive.outer_primitive.bind( - *batched_args, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training, - max_segments_per_seq=max_segments_per_seq, - ), + FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config), out_bdims, ) @staticmethod - def infer_sharding_from_operands( - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, - mesh, - arg_infos, - result_infos, - ): - del attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, max_segments_per_seq - del dropout_probability, is_training, result_infos + def infer_sharding_from_operands(config, mesh, arg_infos, result_infos): + del config, result_infos q_spec = get_padded_spec(arg_infos[0]) k_spec = get_padded_spec(arg_infos[1]) v_spec = get_padded_spec(arg_infos[2]) @@ -950,18 +844,7 @@ def infer_sharding_from_operands( return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) @staticmethod - def partition( - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, - mesh, - arg_infos, - result_infos, - ): + def partition(config, mesh, arg_infos, result_infos): del result_infos q_spec = get_padded_spec(arg_infos[0]) k_spec = get_padded_spec(arg_infos[1]) @@ -1001,17 +884,11 @@ def sharded_impl( kv_cu_seqlen, q_seq_offsets, k_seq_offsets, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training, - max_segments_per_seq=max_segments_per_seq, + config=config, ) global_dbias = local_dbias - if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: - global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) + if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: + global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh) return local_dq, local_dk, local_dv, global_dbias return mesh, sharded_impl, out_shardings, arg_shardings @@ -1020,6 +897,378 @@ def sharded_impl( register_primitive(FusedAttnBwdPrimitive) +@dataclass(frozen=True) +class _FusedAttnCPWithAllGatherHelper: + """Helper class to assist with running the all-gather strategy for CP attention.""" + + mesh: jax.sharding.Mesh + config: _FusedAttnConfig + + def check_supported(self): + """Checks if the context parallel implementation is supported by the given arguments.""" + header = "Context parallel fused attention" + + allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD] + assert self.config.qkv_layout in allowed_layouts, ( + f"{header} only supports layouts: {','.join([str(x) for x in allowed_layouts])} got:" + f" {self.config.qkv_layout}" + ) + + assert ( + self.config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS + ), f"{header} does not support bias got: {self.config.attn_bias_type}" + + allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK] + assert self.config.attn_mask_type in allowed_masks, ( + f"{header} only supports masking types: " + f" {','.join([str(x) for x in allowed_masks])} got: {self.config.attn_mask_type}" + ) + + assert self.config.max_segments_per_seq == 1, ( + f"{header} only supports max_segments_per_seq == 1 got:" + f" {self.config.max_segments_per_seq}" + ) + assert self.config.dropout_probability == 0.0, f"{header} does not support dropout" + 
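To make the slicing arithmetic concrete, here is a small standalone sketch (illustrative only, not part of the patch) that re-derives the per-sub-rank KV lengths quoted in the kv_seqlens_for_rank docstring just below, for CP=4 and a maximum sequence length of 1024. The hypothetical kv_seqlens helper assumes kv_seqlen_per_subrank = max_len // (2 * cp_size), which is how the forward and backward implementations compute it. Because each query chunk attends to a KV prefix of a different length, the causal mask is presumably converted to its bottom-right-aligned variant in get_adjusted_mask below.

def kv_seqlens(cp_rank: int, cp_size: int, max_len: int, balanced: bool) -> list[int]:
    """Per-sub-rank KV lengths for one context-parallel rank (illustrative re-derivation)."""
    per_subrank = max_len // (2 * cp_size)
    if balanced:
        return [(cp_rank + 1) * per_subrank, max_len - cp_rank * per_subrank]
    return [(cp_rank * 2 + 1) * per_subrank, (cp_rank * 2 + 2) * per_subrank]

# Reproduces the docstring examples for CP=4, MaxLen=1024.
assert [kv_seqlens(r, 4, 1024, balanced=False) for r in range(4)] == [
    [128, 256], [384, 512], [640, 768], [896, 1024]
]
assert [kv_seqlens(r, 4, 1024, balanced=True) for r in range(4)] == [
    [128, 1024], [256, 896], [384, 768], [512, 640]
]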
+    def get_adjusted_mask(self):
+        """Converts the mask for context parallelism."""
+        if self.config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK:
+            return NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK
+        return self.config.attn_mask_type
+
+    def all_gather_kv(self, k, v):
+        """Performs an all-gather of k and v over context parallel ranks."""
+
+        def ag(x):
+            return lax_paral_op(
+                x, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True
+            )
+
+        match self.config.qkv_layout:
+            case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
+                return ag(k), v
+            case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
+                return ag(k), ag(v)
+
+        return k, v  # fall through
+
+    def reduce_scatter_dkv(self, dk, dv):
+        """Performs a reduce-scatter of dk and dv over context parallel ranks."""
+
+        def rs(x):
+            return lax_paral_op(
+                x,
+                lax.psum_scatter,
+                self.config.cp_axis,
+                mesh=self.mesh,
+                scatter_dimension=1,
+                tiled=True,
+            )
+
+        match self.config.qkv_layout:
+            case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
+                return rs(dk), dv
+            case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
+                return rs(dk), rs(dv)
+
+        return dk, dv  # fall through
+
+    def kv_seqlens_for_rank(self, cp_rank, kv_max_seqlen, kv_seqlen_per_subrank):
+        """Returns sequence lengths of KV to use for each sub-rank of the given cp_rank.
+
+        Example: CP=4, MaxLen = 1024, Unbalanced
+            cp_rank 0: [128, 256]
+            cp_rank 1: [384, 512]
+            cp_rank 2: [640, 768]
+            cp_rank 3: [896, 1024]
+
+        Example: CP=4, MaxLen = 1024, Balanced
+            cp_rank 0: [128, 1024]
+            cp_rank 1: [256, 896]
+            cp_rank 2: [384, 768]
+            cp_rank 3: [512, 640]
+        """
+        if self.config.context_parallel_load_balanced:
+            kv_seq_this_rank = [
+                (cp_rank + 1) * kv_seqlen_per_subrank,
+                kv_max_seqlen - cp_rank * kv_seqlen_per_subrank,
+            ]
+        else:
+            kv_seq_this_rank = [
+                (cp_rank * 2 + 1) * kv_seqlen_per_subrank,
+                (cp_rank * 2 + 2) * kv_seqlen_per_subrank,
+            ]
+        return kv_seq_this_rank
+
+    def slice_kv(self, k, v, slice_seq_len):
+        """Slices k and v tensors to a sequence length of slice_seq_len."""
+
+        def sliced(x):
+            return lax.dynamic_slice_in_dim(x, 0, slice_seq_len, axis=1)
+
+        match self.config.qkv_layout:
+            case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
+                return sliced(k), v
+            case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
+                return sliced(k), sliced(v)
+
+        return k, v  # fall through
+
+    def pad_kv(self, dk, dv, pad_seq_len):
+        """Pads dk and dv tensors to a sequence length of pad_seq_len."""
+
+        def pad(x, npad):
+            return jnp.pad(x, npad, "constant", constant_values=0.0)
+
+        match self.config.qkv_layout:
+            case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
+                npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0], [0, 0]]
+                return pad(dk, npad), dv
+            case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
+                npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0]]
+                return pad(dk, npad), pad(dv, npad)
+
+        return dk, dv  # fall through
+
+
+class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
+    """
+    Fused Attention Forward with Context Parallelism Primitive
+
+    This context parallel implementation uses all-gather to collect KV inputs from context parallel ranks.
+    """
+
+    @staticmethod
+    def partition(config, mesh, arg_infos, result_infos):
+        # Call base implementation for non-context parallel mesh to avoid unnecessary work.
+ is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + if not is_context_parallel: + return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos) + + helper = _FusedAttnCPWithAllGatherHelper(mesh, config) + helper.check_supported() + + out_sharding = result_infos[0].sharding + softmax_aux_sharding = result_infos[1].sharding + rng_state_sharding = seed_sharding = NamedSharding( + mesh, PartitionSpec(get_all_mesh_axes(), None) + ) + arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) + out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) + + def impl(q, k, v, bias, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets, seed): + + cp_size = get_mesh_axis_size(config.cp_axis, mesh) + cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) + + # cuDNN does not support right-aligned masking with dynamic sequence length padding. + # Therefore we must explicitly instantiate each CP rank slicing and use a runtime switch + # to select the appropriate computation. Each case generates a [..., SEQ/CP, ..] tensor + # meeting the expectation of the SPMD model. + # TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding + # mask/sequence length tensor to avoid this unrolled loop. + def _cross_attn(idx, q, k, v, bias, q_seqlen, kv_seqlen, seed): + kv_max_seqlen = k.shape[1] + kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2) + assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size" + + q_split = jnp.split(q, 2, axis=1) + + kv_seqlens_for_rank = helper.kv_seqlens_for_rank( + idx, kv_max_seqlen, kv_seqlen_per_subrank + ) + + results = [] + for sub_idx in range(2): + if config.attn_mask_type == NVTE_Mask_Type.NVTE_NO_MASK: + k_unmasked, v_unmasked = k, v # full kv used for unmasked + else: + k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx]) + + q_seqlen_for_step = q_seqlen / (cp_size * 2) + num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx] + kv_seqlen_for_step = (kv_seqlen / (cp_size * 2)) * num_kv_chunks + + output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl( + q_split[sub_idx], + k_unmasked, + v_unmasked, + bias, + q_seqlen_for_step, + kv_seqlen_for_step, + q_seq_offsets, + k_seq_offsets, + seed, + config=config, + ) + results.append((output, softmax_aux, rng_state)) + + output = jnp.concatenate((results[0][0], results[1][0]), axis=1) + softmax_aux = jnp.concatenate((results[0][1], results[1][1]), axis=2) + rng_state = results[1][2] # Use the final RNG state + + return output, softmax_aux, rng_state + + k_ag, v_ag = helper.all_gather_kv(k, v) + + functions = [ + partial(_cross_attn, idx, q, k_ag, v_ag, bias, q_seqlen, kv_seqlen, seed) + for idx in range(cp_size) + ] + + return lax.switch(cp_rank, functions) + + return mesh, impl, out_shardings, arg_shardings + + +register_primitive(FusedAttnCPWithAllGatherFwdPrimitive) + + +class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): + """ + Fused Attention Backward with Context Parallelism Primitive. + + This context parallel implementation uses all-gather to collect KV and dKV inputs from context parallel ranks. + The gradients are subsequently reduce-scattered back to each context parallel rank. + """ + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + # Call base implementation for non-context parallel mesh to avoid unecessary work. 
+ is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + if not is_context_parallel: + return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos) + + # Ensure we can support this configuration with context parallelism. + helper = _FusedAttnCPWithAllGatherHelper(mesh, config) + helper.check_supported() + + del result_infos + q_spec = get_padded_spec(arg_infos[0]) + k_spec = get_padded_spec(arg_infos[1]) + v_spec = get_padded_spec(arg_infos[2]) + bias_spec = get_padded_spec(arg_infos[3]) + dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) + dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) + dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) + + def impl( + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + ): + cp_size = get_mesh_axis_size(config.cp_axis, mesh) + cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) + + # See comment in FusedAttnCPFwdPrimitive.partition for why we define this function. + def _cross_attn_bwd( + idx, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen + ): + kv_max_seqlen = k.shape[1] + kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2) + assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size" + + q_split = jnp.split(q, 2, axis=1) + output_split = jnp.split(output, 2, axis=1) + doutput_split = jnp.split(doutput, 2, axis=1) + softmax_aux_split = jnp.split(softmax_aux, 2, axis=2) + + kv_seqlens_for_rank = helper.kv_seqlens_for_rank( + idx, kv_max_seqlen, kv_seqlen_per_subrank + ) + + results = [] + for sub_idx in range(2): + if config.attn_mask_type == NVTE_Mask_Type.NVTE_NO_MASK: + k_unmasked, v_unmasked = k, v # full kv used for unmasked + else: + k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx]) + + q_seqlen_for_step = q_seqlen // (cp_size * 2) + num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx] + kv_seqlen_for_step = (kv_seqlen // (cp_size * 2)) * num_kv_chunks + + dq_local, dk_local, dv_local, dbias_local = FusedAttnBwdPrimitive.impl( + q_split[sub_idx], + k_unmasked, + v_unmasked, + bias, + softmax_aux_split[sub_idx], + rng_state, + output_split[sub_idx], + doutput_split[sub_idx], + q_seqlen_for_step, + kv_seqlen_for_step, + q_seq_offsets, + k_seq_offsets, + config=config, + ) + + # pad dk/dv to be unsliced shape so we can reduce scatter over all ranks. 
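+                    # Each sub-step computed dk/dv only for the first kv_seqlens_for_rank[sub_idx]
+                    # positions of the gathered KV, so zero-pad them back to the full gathered
+                    # length; reduce_scatter_dkv below then sums the contributions from all CP
+                    # ranks and hands each rank its own sequence shard.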
+ if config.attn_mask_type != NVTE_Mask_Type.NVTE_NO_MASK: + pad_length = kv_max_seqlen - kv_seqlens_for_rank[sub_idx] + dk_local, dv_local = helper.pad_kv(dk_local, dv_local, pad_length) + + results.append((dq_local, dk_local, dv_local, dbias_local)) + + dq_local = jnp.concatenate((results[0][0], results[1][0]), axis=1) + dk_local_pad = results[0][1] + results[1][1] + dv_local_pad = results[0][2] + results[1][2] + return dq_local, dk_local_pad, dv_local_pad, results[1][3] + + k_ag, v_ag = helper.all_gather_kv(k, v) + + functions = [ + partial( + _cross_attn_bwd, + idx, + q, + k_ag, + v_ag, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + ) + for idx in range(cp_size) + ] + + dq, dk_local, dv_local, dbias = lax.switch(cp_rank, functions) + dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local) + + return dq, dk, dv, dbias + + return mesh, impl, out_shardings, arg_shardings + + +register_primitive(FusedAttnCPWithAllGatherBwdPrimitive) + + +def _maybe_context_parallel_axis(cp_axis: str): + if not cp_axis: + gmr = global_mesh_resource() + if gmr is not None: + cp_axis = gmr.cp_resource + else: + cp_axis = "" + return cp_axis + + def fused_attn_fwd( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], @@ -1035,6 +1284,8 @@ def fused_attn_fwd( dropout_probability: float, is_training: bool, max_segments_per_seq: int, + context_parallel_causal_load_balanced: bool = False, + context_parallel_axis: str = "", ) -> jnp.ndarray: """ Perform the forward pass of with cuDNN fused attention implementations. @@ -1063,6 +1314,9 @@ def fused_attn_fwd( scaling_factor (float): Scaling factor for the attention scores. dropout_probability (float): Dropout probability to apply during attention. is_training (bool): Flag indicating whether the model is in training mode. + context_parallel_causal_load_balanced (bool): + Indicates the sequences are ordered for causal mask load balancing when running context parallelism. + context_parallel_axis (str): The name of the context parallel axis. Returns: (jnp.ndarray): The output tensor from the fused attention. """ @@ -1094,14 +1348,7 @@ def fused_attn_fwd( assert bias is None bias = jnp.zeros(0, dtype=qkv[0].dtype) - return FusedAttnFwdPrimitive.outer_primitive.bind( - *qkv_for_primitive, - bias, - q_seqlen, - kv_seqlen, - q_seq_offsets if is_ragged else _not_used, - kv_seq_offsets if is_ragged else _not_used, - seed, + fused_config = _FusedAttnConfig( attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, qkv_layout=qkv_layout, @@ -1109,6 +1356,19 @@ def fused_attn_fwd( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, + context_parallel_load_balanced=context_parallel_causal_load_balanced, + cp_axis=_maybe_context_parallel_axis(context_parallel_axis), + ) + + return FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive.bind( + *qkv_for_primitive, + bias, + q_seqlen, + kv_seqlen, + q_seq_offsets if is_ragged else _not_used, + kv_seq_offsets if is_ragged else _not_used, + seed, + config=fused_config, ) @@ -1130,6 +1390,8 @@ def fused_attn_bwd( dropout_probability: float, is_training: bool, max_segments_per_seq: int, + context_parallel_causal_load_balanced: bool = False, + context_parallel_axis: str = "", ): """ Perform the backward pass of the cuDNN fused attention implementations. @@ -1159,7 +1421,9 @@ def fused_attn_bwd( scaling_factor (float): Scaling factor for the attention scores. dropout_probability (float): Dropout probability to apply during attention. 
is_training (bool): Flag indicating whether the model is in training mode. - + context_parallel_causal_load_balanced (bool): + Indicates the sequences are ordered for causal mask load balancing when running context parallelism. + context_parallel_axis (str): The name of the context parallel axis. Returns: Tuple[jnp.ndarray, ...], jnp.ndarray: - The first tuple contains the gradients with respect to the input `qkv` tensors in the @@ -1194,7 +1458,19 @@ def fused_attn_bwd( assert bias is None bias = jnp.zeros(0, dtype=qkv[0].dtype) - *qkv_grads, bias_grad = FusedAttnBwdPrimitive.outer_primitive.bind( + fused_config = _FusedAttnConfig( + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + max_segments_per_seq=max_segments_per_seq, + context_parallel_load_balanced=context_parallel_causal_load_balanced, + cp_axis=_maybe_context_parallel_axis(context_parallel_axis), + ) + + *qkv_grads, bias_grad = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive.bind( *qkv_for_primitive, bias, softmax_aux, @@ -1205,12 +1481,6 @@ def fused_attn_bwd( kv_seqlen, q_seq_offsets if is_ragged else _not_used, kv_seq_offsets if is_ragged else _not_used, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training, - max_segments_per_seq=max_segments_per_seq, + config=fused_config, ) return tuple(qkv_grads[: len(qkv)]), bias_grad diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index caf9272b02..e85f28a06a 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -533,8 +533,8 @@ def sharded_impl(dz, x, mu, rsigma, gamma): local_dx, local_dgamma, local_dbeta = LayerNormBwdPrimitive.impl( dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon ) - global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma) - global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta) + global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh) + global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta, mesh) return local_dx, global_dgamma, global_dbeta return mesh, sharded_impl, out_shardings, arg_shardings @@ -935,7 +935,7 @@ def partition(epsilon, mesh, arg_infos, result_infos): def sharded_impl(dz, x, rsigma, gamma): local_dx, local_dgamma = RmsNormBwdPrimitive.impl(dz, x, rsigma, gamma, epsilon=epsilon) - global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma) + global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh) return local_dx, global_dgamma return mesh, sharded_impl, out_shardings, arg_shardings @@ -1228,7 +1228,7 @@ def sharded_impl(x, gamma, beta, amax, scale, scale_inv): zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, ) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) return local_x, local_mu, local_rsigma, global_updated_amax @@ -1481,7 +1481,7 @@ def sharded_impl(x, gamma, amax, scale, scale_inv): local_x, local_rsigma, local_amax = RmsNormFwdFp8Primitive.impl( x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon ) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) + global_updated_amax = 
all_reduce_max_along_all_axes_except_PP(local_amax, mesh) return local_x, local_rsigma, global_updated_amax diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 2c529e71c8..48bf4d969a 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -157,7 +157,7 @@ def sharded_impl(x, amax, scale, scale_inv): local_cx, local_updated_amax = CastFP8Primitive.impl( x, amax, scale, scale_inv, out_dtype=out_dtype ) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh) return local_cx, global_updated_amax diff --git a/transformer_engine/jax/cpp_extensions/transpose.py b/transformer_engine/jax/cpp_extensions/transpose.py index e503792dc0..963d7f09e8 100644 --- a/transformer_engine/jax/cpp_extensions/transpose.py +++ b/transformer_engine/jax/cpp_extensions/transpose.py @@ -390,7 +390,7 @@ def sharded_impl(x, amax, scale, scale_inv): static_axis_boundary=static_axis_boundary, transpose_axis_boundary=transpose_axis_boundary, ) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh) return local_cx, local_cxt, global_updated_amax @@ -646,8 +646,8 @@ def sharded_impl(dz, amax, scale, scale_inv): static_axis_boundary=static_axis_boundary, transpose_axis_boundary=transpose_axis_boundary, ) - global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) + global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) return local_out, local_t_out, global_dbias, global_updated_amax return mesh, sharded_impl, out_shardings, arg_shardings @@ -981,8 +981,8 @@ def sharded_impl(dz, x, amax, scale, scale_inv): act_enum=act_enum, ) ) - global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) + global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) return local_out, local_t_out, global_dbias, global_updated_amax return mesh, sharded_impl, out_shardings, arg_shardings @@ -1225,7 +1225,7 @@ def sharded_impl(dz, x, amax, scale, scale_inv): static_axis_boundary=static_axis_boundary, act_enum=act_enum, ) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) return local_out, local_t_out, global_updated_amax return mesh, sharded_impl, out_shardings, arg_shardings diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 0a2172bb1b..14f449a76b 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -100,7 +100,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) - .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK); + .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) + 
.value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) + .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", + NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index c0b60fe61e..a14a8384cf 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -20,6 +20,7 @@ BATCH_AXES = "nvte_batch" SEQLEN_AXES = "nvte_seqlen" SEQLEN_TP_AXES = "nvte_seqlen_tp" +SEQLEN_CP_AXES = "nvte_seqlen_cp" HEAD_AXES = "nvte_head" HIDDEN_AXES = "nvte_hidden" HIDDEN_TP_AXES = "nvte_hidden_tp" @@ -30,8 +31,7 @@ W_JOINED_AXES = "nvte_w_joined" -def _get_mesh_info(resource: str): - mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh +def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh): assert resource in mesh.axis_names, f"{resource} is not in the axis_names of Mesh {mesh}." return mesh.shape[resource], resource @@ -66,6 +66,7 @@ def get_sharding_map_logic_axis_to_mesh_axis(): BATCH_AXES: batch_dim_rule, SEQLEN_AXES: None, SEQLEN_TP_AXES: gsr.tp_resource, + SEQLEN_CP_AXES: gsr.cp_resource, HEAD_AXES: gsr.tp_resource, HIDDEN_AXES: None, HIDDEN_TP_AXES: gsr.tp_resource, @@ -132,13 +133,15 @@ def get_padded_spec(spec, ndim): return spec + (None,) * (ndim - len(spec)) -def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str): +def lax_paral_op( + x: jnp.array, ops: Callable, mesh_resource: str, mesh: jax.sharding.Mesh, **kwargs +): """ A wrapper function to invoke lax.p* operations, like psum. """ if mesh_resource is not None: - _, resource = _get_mesh_info(mesh_resource) - return ops(x, resource) + _, resource = _get_mesh_info(mesh_resource, mesh) + return ops(x, resource, **kwargs) return x @@ -149,6 +152,33 @@ def num_of_devices(): return len(jax.devices()) +def get_mesh_axis_size(axis, mesh=None): + """ + Get the axis size of the given mesh. + If the mesh is None, it would be replaced + by the global mesh. + """ + if mesh is None: + mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh + + if axis is None: + return 1 + + assert axis in mesh.shape, f"{axis} is not a axis of the given mesh {mesh.shape}" + return mesh.shape[axis] + + +def get_mesh_axis_rank(axis: str, mesh=None): + """ + Gets the local axis rank of the `axis` of the array. + If the mesh is None the rank is 0. + """ + if mesh is None: + return 0 + _, axis_name = _get_mesh_info(axis, mesh) + return jax.lax.axis_index(axis_name) + + @dataclass class MeshResource: """ @@ -169,12 +199,16 @@ class MeshResource: pp_resource : str, default = None The axis name in Mesh used to split model layers. along. If it is None, then pipeline parallelism is disabled. + cp_resource : str, default = None + The axis name in Mesh used to split sequence (context) dimensions along + in the attention. If it is None, then context parallelism is disabled. """ dp_resource: str = None tp_resource: str = None fsdp_resource: str = None pp_resource: str = None + cp_resource: str = None _GLOBAL_MESH_RESOURCE = MeshResource() @@ -201,22 +235,22 @@ def global_mesh_resource() -> MeshResource: return _GLOBAL_MESH_RESOURCE -def all_reduce_sum_along_dp_fsdp(x: jnp.array): +def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh): """ All-Reduce (Sum) along DP and FSDP mesh axes. 
""" - x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource) - return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource) + x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource, mesh) + return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh) -def all_reduce_max_along_all_axes_except_PP(x: jnp.array): +def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh): """ All-Reduce (Max) along all mesh axes. """ all_axes = get_all_mesh_axes() for axis in all_axes: if axis != global_mesh_resource().pp_resource: - x = lax_paral_op(x, jax.lax.pmax, axis) + x = lax_paral_op(x, jax.lax.pmax, axis, mesh) return x diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 07ade71905..c4097333d3 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -67,12 +67,14 @@ def _load_library(): from transformer_engine.pytorch.module import LayerNorm from transformer_engine.pytorch.module import RMSNorm from transformer_engine.pytorch.module import GroupedLinear +from transformer_engine.pytorch.module import Fp8Padding, Fp8Unpadding from transformer_engine.pytorch.module import initialize_ub from transformer_engine.pytorch.module import destroy_ub from transformer_engine.pytorch.attention import DotProductAttention from transformer_engine.pytorch.attention import InferenceParams from transformer_engine.pytorch.attention import MultiheadAttention from transformer_engine.pytorch.transformer import TransformerLayer +from transformer_engine.pytorch.permutation import moe_permute, moe_unpermute from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.fp8 import fp8_model_init from transformer_engine.pytorch.graph import make_graphed_callables diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index ff121527d3..bacadf2cd5 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -6,6 +6,7 @@ import collections from contextlib import nullcontext from importlib.metadata import version as get_pkg_version +from importlib.metadata import PackageNotFoundError import math import os from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -37,8 +38,20 @@ AttnBiasType, AttnMaskType, FusedAttnBackend, + META_QKV, + META_DQKV, + META_O, + META_DO, + META_S, + META_DP, + META_O_CP, + META_DQKV_CP, +) +from transformer_engine.pytorch.fp8 import ( + FP8GlobalStateManager, + get_fp8_te_dtype, + get_fp8_torch_dtype, ) -from transformer_engine.pytorch.fp8 import get_fp8_te_dtype from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.module import LayerNormLinear, Linear from transformer_engine.pytorch.module.base import TransformerEngineBaseModule @@ -73,31 +86,52 @@ from transformer_engine.pytorch.graph import is_graph_capturing +_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) +_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1")) +_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) _flash_attn_version = PkgVersion(get_pkg_version("flash-attn")) _flash_attn_version_required = PkgVersion("2.0.6") -_flash_attn_max_version = PkgVersion("2.5.8") +_flash_attn_max_version = PkgVersion("2.6.3") _flash_attn_2_plus = _flash_attn_version >= PkgVersion("2") _flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1") _flash_attn_2_3_plus = _flash_attn_version >= 
PkgVersion("2.3") _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4") _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1") _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7") +_flash_attn_3_plus = False +_use_flash_attn_3 = False +try: + _flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper")) + _flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.6.1") +except PackageNotFoundError: + if get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN: + warnings.warn( + "To use flash-attn v3, please use the following commands to install: \n" + """(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n""" + """(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n""" + """(3) mkdir -p $python_path/flashattn_hopper \n""" + """(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py""" + ) +else: + from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3 + from flashattn_hopper.flash_attn_interface import ( + flash_attn_varlen_func as flash_attn_varlen_func_v3, + ) + from flashattn_hopper.flash_attn_interface import ( # pylint: disable=unused-import + _flash_attn_forward as _flash_attn_forward_v3, + ) + from flashattn_hopper.flash_attn_interface import ( # pylint: disable=unused-import + _flash_attn_backward as _flash_attn_backward_v3, + ) + + _use_flash_attn_3 = True if _flash_attn_version >= _flash_attn_version_required: - from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func + from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd -META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT -META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1 -META_O = tex.FP8FwdTensors.GEMM2_INPUT -META_DO = tex.FP8BwdTensors.GRAD_INPUT2 -META_S = tex.FP8FwdTensors.GEMM3_OUTPUT -META_DP = tex.FP8BwdTensors.GRAD_INPUT3 -# repurpose some unused amax history buffers for partial results of CP fwd and bwd -META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT -META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1 # NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) @@ -110,10 +144,6 @@ _stream_handler = logging.StreamHandler() _stream_handler.setFormatter(_formatter) -_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) -_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1")) -_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) - _attention_backends = { "attention_params": None, "use_flash_attention": None, @@ -318,6 +348,7 @@ def get_attention_backend( use_fused_attention = False # Filter: Compute capability + global _flash_attn_3_plus, _use_flash_attn_3 if device_compute_capability < (8, 0): if use_flash_attention: logger.debug("Disabling FlashAttention as it requires compute capability sm80+") @@ -325,32 +356,42 @@ def get_attention_backend( if use_fused_attention: logger.debug("Disabling FusedAttention as it requires compute capability sm80+") use_fused_attention = False + if device_compute_capability < (9, 0): + if use_flash_attention and _flash_attn_3_plus: + logger.debug("Disabling FlashAttention 3 as it 
requires compute capability sm90+") + _use_flash_attn_3 = False # Filter: Data type - if use_flash_attention and ( - qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type == Float8Tensor - ): - logger.debug( - "Disabling FlashAttention due to unsupported QKV data type. " - "Supported: qkv_type = torch.Tensor, qkv_dtype = {torch.bfloat16, torch.float16}. " - "Found: qkv_type = %s, qkv_dtype = %s.", - qkv_type, - qkv_dtype, - ) - use_flash_attention = False - if use_fused_attention and (qkv_dtype not in [torch.bfloat16, torch.float16]): - logger.debug( - "Disabling FusedAttention due to unsupported QKV data type. " - "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. " - "Found: qkv_dtype = %s.", - qkv_dtype, - ) - use_fused_attention = False + if qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type not in [ + torch.Tensor, + Float8Tensor, + ]: + if use_flash_attention: + logger.debug( + "Disabling FlashAttention due to unsupported QKV data type. " + "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. " + "Found: qkv_dtype = %s.", + qkv_dtype, + ) + use_flash_attention = False + if use_fused_attention: + logger.debug( + "Disabling FusedAttention due to unsupported QKV data type. " + "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. " + "Found: qkv_dtype = %s.", + qkv_dtype, + ) + use_fused_attention = False # Filter: Execution type if fp8 and fp8_meta["recipe"].fp8_dpa: - if use_flash_attention: - logger.debug("Disabling FlashAttention as it does not support FP8") + if use_flash_attention and not _use_flash_attn_3: + logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8") + use_flash_attention = False + if use_flash_attention and _use_flash_attn_3 and is_training: + logger.debug( + "Disabling FlashAttention as FlashAttention 3 does not support FP8 training" + ) use_flash_attention = False if use_unfused_attention: logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8") @@ -396,6 +437,12 @@ def get_attention_backend( ) use_flash_attention = False + # Filter: Dropout + if attention_dropout != 0.0 and use_flash_attention: + if _flash_attn_3_plus and _use_flash_attn_3: + logger.debug("Disabling FlashAttention 3 for dropout") + _use_flash_attn_3 = False + # Filter: Context parallelism # qkv_format | attn_mask_type | attn_bias_type | supported backends # ---------------------------------------------------------------------------------------------------- @@ -414,6 +461,14 @@ def get_attention_backend( ) use_unfused_attention = False if context_parallel and use_flash_attention: + if _flash_attn_3_plus and _use_flash_attn_3: + logger.debug("Disabling FlashAttention 3 for context parallelism") + _use_flash_attn_3 = False + if fp8 and fp8_meta["recipe"].fp8_dpa: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with FP8" + ) + use_flash_attention = False if "bottom_right" in attn_mask_type: logger.debug( "Disabling FlashAttention as it does not support context parallelism with" @@ -439,6 +494,7 @@ def get_attention_backend( " bias for THD format" ) use_flash_attention = False + if context_parallel and use_fused_attention: if "bottom_right" in attn_mask_type: logger.debug( @@ -498,6 +554,18 @@ def get_attention_backend( if use_fused_attention: logger.debug("Disabling FusedAttention for arbitrary mask") use_fused_attention = False + if ( + use_flash_attention + and _flash_attn_3_plus + and attn_mask_type in ["causal", "padding_causal"] + and max_seqlen_q != max_seqlen_kv + ): + 
logger.warning( + "Disabling FlashAttention 3 as it only supports bottom-right-diagonal " + "causal mask since flash-attn 2.1. See " + "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" + ) + _use_flash_attn_3 = False if ( use_flash_attention and _flash_attn_2_1_plus @@ -546,12 +614,6 @@ def get_attention_backend( "with causal mask, no dropout, and qkv_format = bshd/sbhd" ) use_fused_attention = False - elif context_parallel: - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention " - "with context parallelism" - ) - use_fused_attention = False elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [ "no_mask", "padding", @@ -571,6 +633,15 @@ def get_attention_backend( attn_mask_type, ) use_fused_attention = False + if ( + use_flash_attention + and (window_size[0] != -1 or window_size[1] not in [-1, 0]) + and _flash_attn_3_plus + ): + logger.debug( + "Disabling FlashAttention 3 as it does not support sliding window attention" + ) + _use_flash_attn_3 = False if ( use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]) @@ -590,6 +661,14 @@ def get_attention_backend( # | | bottom_right (converts to a 'post_scale_bias' bias) # UnfusedDotProductAttention | no_bias, pre/post_scale_bias | # | alibi/alibi_slopes | both; converts to a 'post_scale_bias' bias + if use_flash_attention and core_attention_bias_type == "alibi": + if _flash_attn_3_plus and _use_flash_attn_3: + logger.debug("Disabling FlashAttention 3 for ALiBi") + _use_flash_attn_3 = False + if not _flash_attn_2_4_plus: + logger.debug("Disabling FlashAttention for ALiBi") + use_flash_attention = False + if use_flash_attention and ( core_attention_bias_type not in ["no_bias", "alibi"] or core_attention_bias_shape is not None @@ -749,6 +828,21 @@ def get_attention_backend( ) use_flash_attention = False + # Select FusedAttention for FP8 + # FA3 uses default scaling factors (i.e. 
1) in FP8 execution, while FusedAttention takes + # scaling factors from `fp8_meta` and offers more accurate quantization/de-quantization + if ( + use_flash_attention + and use_fused_attention + and fused_attention_backend == FusedAttnBackend["FP8"] + and _use_flash_attn_3 + ): + logger.debug( + "Disabling FlashAttention 3 to give FusedAttention preference as FusedAttention " + "supports more accurate scaling factors in FP8 execution" + ) + use_flash_attention = False + # Selected backend if use_flash_attention: use_fused_attention = False @@ -1071,7 +1165,7 @@ def _get_full_cu_seqlens( return _cu_seqlens_cache[(batch_size, max_seqlen)] -@jit_fuser +@torch.compile def pack_tensor( indices: torch.Tensor, tensor: torch.Tensor, @@ -1082,14 +1176,19 @@ def pack_tensor( padding_indice = torch.zeros( 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device ) - tensor = torch.cat((tensor, padding_indice), dim=0) - indices = indices.repeat(1, tensor.shape[1], tensor.shape[2]) - packed = torch.gather(tensor, 0, indices) + if isinstance(tensor, Float8Tensor): + tensor_data = torch.cat((tensor._data, padding_indice), dim=0) + + packed = Float8Tensor.make_like(tensor, data=torch.gather(tensor_data, 0, indices)) + else: + tensor = torch.cat((tensor, padding_indice), dim=0) + + packed = torch.gather(tensor, 0, indices) return packed -@jit_fuser +@torch.compile def pack_2_tensors( indices: torch.Tensor, t1: torch.Tensor, @@ -1103,7 +1202,7 @@ def pack_2_tensors( return t1_packed, t2_packed -@jit_fuser +@torch.compile def pack_3_tensors( indices: torch.Tensor, t1: torch.Tensor, @@ -1119,7 +1218,7 @@ def pack_3_tensors( return t1_packed, t2_packed, t3_packed -@jit_fuser +@torch.compile def unpack_tensor( indices: torch.Tensor, dim0: int, @@ -1132,12 +1231,16 @@ def unpack_tensor( unpacked = torch.zeros( dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device ) - unpacked.scatter_(0, indices, tensor) - unpacked = unpacked[0:-1, :, :] + if isinstance(tensor, Float8Tensor): + unpacked.scatter_(0, indices, tensor._data) + unpacked = Float8Tensor.make_like(tensor, data=unpacked[0:-1, :, :]) + else: + unpacked.scatter_(0, indices, tensor) + unpacked = unpacked[0:-1, :, :] return unpacked -@jit_fuser +@torch.compile def unpack_2_tensors( indices: torch.Tensor, dim0: int, @@ -1152,7 +1255,7 @@ def unpack_2_tensors( return t1_unpacked, t2_unpacked -@jit_fuser +@torch.compile def unpack_3_tensors( indices: torch.Tensor, dim0: int, @@ -1320,9 +1423,6 @@ def forward( cu_seqlens_q_padded, cu_seqlens_kv_padded, dropout_p, - cp_group, - cp_global_ranks, - cp_stream, softmax_scale, qkv_format, attn_mask_type, @@ -1332,6 +1432,9 @@ def forward( use_fused_attention, fp8, fp8_meta, + cp_group, + cp_global_ranks, + cp_stream, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -1440,10 +1543,14 @@ def forward( for x in [k_f16, v_f16] ] fp8_meta_kwargs = {} - fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv[META_QKV] - fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv[META_S] - fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale[META_S] - fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale[META_O_CP] + fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv + fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV + fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv + fp8_meta_kwargs["d_scale_s_offset"] = META_S + fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale + 
fp8_meta_kwargs["q_scale_s_offset"] = META_S + fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale + fp8_meta_kwargs["q_scale_o_offset"] = META_O_CP amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) else: assert False, "FP8 is only supported with Fused Attention!" @@ -1495,8 +1602,10 @@ def forward( fp8_dtype_forward, ) if fp8 and use_fused_attention: - fp8_meta_kwargs["amax_s"] = amax_per_step[0][i] - fp8_meta_kwargs["amax_o"] = amax_per_step[1][i] + fp8_meta_kwargs["amax_s"] = amax_per_step + fp8_meta_kwargs["amax_s_offset"] = i + fp8_meta_kwargs["amax_o"] = amax_per_step + fp8_meta_kwargs["amax_o_offset"] = cp_size + i if causal: if i == 0: if pad_between_seqs_q: @@ -2831,10 +2940,10 @@ def backward(ctx, dout): None, None, None, + attn_dbias, None, None, None, - attn_dbias, None, None, None, @@ -2843,30 +2952,56 @@ def backward(ctx, dout): @torch.compile -def get_seq_chunk_ids_to_all_gathered_kv( - local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size_left, device +def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous): + """ + Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing. + To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks + before or after CP communications (e.g., all-gather, all-to-all). This function is to compute + sequence chunk ids for reordering. + """ + chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device) + if to_contiguous: + for rank in range(cp_size): + chunk_ids[rank] = 2 * rank + chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1 + else: + for rank in range(cp_size): + chunk_ids[2 * rank] = rank + chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1 + return chunk_ids + + +def get_kv_seq_info_after_all_gather( + local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size, causal ): - """Compute sequence chunk ids to the all-gathered KV.""" - seq_end_idx = (local_chunk_id + 1) * max_seqlen_kv - seq_start_idx = max(0, seq_end_idx - max_seqlen_q - window_size_left) - seqlen = seq_end_idx - seq_start_idx - num_chunks = (seqlen + max_seqlen_kv - 1) // max_seqlen_kv - chunk_ids = torch.arange( - local_chunk_id - num_chunks + 1, - local_chunk_id + 1, - dtype=torch.int32, - device=device, - ) - chunk_ids_to_all_gathered_kv = torch.where( - chunk_ids < cp_size, 2 * chunk_ids, 2 * (2 * cp_size - chunk_ids) - 1 - ) - return chunk_ids_to_all_gathered_kv + """Compute KV sequence index range and update window size after all-gather.""" + local_chunk_end_idx = (local_chunk_id + 1) * max_seqlen_kv + full_seq_end_idx = max_seqlen_kv * cp_size * 2 + + if window_size is None: + window_size = (-1, 0) if causal else (-1, -1) + + if window_size[1] == -1: + seq_end_idx = full_seq_end_idx + window_size_right = -1 + else: + seq_end_idx = min(full_seq_end_idx, local_chunk_end_idx + window_size[1]) + window_size_right = local_chunk_end_idx + window_size[1] - seq_end_idx + + if window_size[0] == -1: + seq_start_idx = 0 + window_size_left = -1 + else: + seq_start_idx = max(0, local_chunk_end_idx - max_seqlen_q - window_size[0]) + window_size_left = window_size[0] + seq_end_idx - local_chunk_end_idx + + return (seq_start_idx, seq_end_idx), (window_size_left, window_size_right) class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): """ - Attention implementation with context parallelism. - KV all-gather between CP ranks is exposed. + Attention implementation with context parallelism. 
KV all-gather between CP ranks is exposed. + Refer section 3.3.2 of `The Llama 3 Herd of Models `_. """ @staticmethod @@ -2877,14 +3012,10 @@ def forward( k, v, cu_seqlens_q, - cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, cu_seqlens_q_padded, - cu_seqlens_kv_padded, dropout_p, - cp_group, - cp_stream, softmax_scale, qkv_format, attn_mask_type, @@ -2893,6 +3024,8 @@ def forward( deterministic, use_fused_attention, window_size, + cp_group, + cp_stream, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -2902,10 +3035,9 @@ def forward( causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type - assert causal and not padding, f"{attn_mask_type} mask type is not supported!" + assert not padding, f"{attn_mask_type} mask type is not supported!" if use_fused_attention and causal and "bottom_right" not in attn_mask_type: attn_mask_type = attn_mask_type + "_bottom_right" - assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" assert ( @@ -2914,6 +3046,8 @@ def forward( fa_optional_forward_kwargs = {} if _flash_attn_2_4_plus: fa_optional_forward_kwargs["alibi_slopes"] = None + if _flash_attn_2_5_7_plus: + fa_optional_forward_kwargs["block_table"] = None assert qkv_format != "thd", f"{qkv_format} format is not supported!" qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format @@ -2926,31 +3060,35 @@ def forward( max_seqlen_q = max_seqlen_q // (2 * cp_size) max_seqlen_kv = max_seqlen_kv // (2 * cp_size) cu_seqlens_q = cu_seqlens_q // (2 * cp_size) - cu_seqlens_kv = cu_seqlens_kv // (2 * cp_size) cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size) - cu_seqlens_kv_padded = cu_seqlens_kv_padded // (2 * cp_size) - - if causal: - if qkv_format == "bshd": - # [b, s, np, hn] -> [b, 2, s//2, np, hn] - q = q.view(q.shape[0], 2, q.shape[1] // 2, *q.shape[2:]) - # [b, s, np, hn] -> [s, b, np, hn] - k, v = [x.transpose(0, 1).contiguous() for x in [k, v]] - elif qkv_format == "sbhd": - # [s, b, np, hn] -> [2, s//2, b, np, hn] - q = q.view(2, q.shape[0] // 2, *q.shape[1:]) - # create two streams to resolve wave quantization issue of Flash Attn in each step - flash_attn_streams = [torch.cuda.current_stream(), cp_stream] + # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn] + q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) + # [b, s, np, hn] or [s, b, np, hn] -> [s, b, np, hn] + k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] + # [s, b, np, hn] -> [cp, s, b, np, hn] k_ag, _ = gather_along_first_dim(k, cp_group) v_ag, _ = gather_along_first_dim(v, cp_group) - cp_stream.wait_stream(torch.cuda.current_stream()) + + # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True) + k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) + v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] + k_ag = k_ag.view(-1, *k.shape[1:]) + v_ag = v_ag.view(-1, *v.shape[1:]) + cp_stream.wait_stream(torch.cuda.current_stream()) + + # create two streams to resolve wave quantization issue of Flash Attn in each step + flash_attn_streams = [torch.cuda.current_stream(), cp_stream] local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] - 
chunk_ids_to_kv_ag_per_step = [None, None] + kv_seq_range_per_step = [None, None] + window_size_per_step = [None, None] + cu_seqlens_kv_per_step = [None, None] out_per_step = [None, None] softmax_lse_per_step = [None, None] rng_states = [None, None] @@ -2959,53 +3097,36 @@ def forward( for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - chunk_ids_to_kv_ag = get_seq_chunk_ids_to_all_gathered_kv( - local_seq_chunk_ids[i], - cp_size, - max_seqlen_q, - max_seqlen_kv, - ( - max_seqlen_kv * cp_size * 2 - if (window_size is None or window_size[0] == -1) - else window_size[0] - ), - k.device, - ) - chunk_ids_to_kv_ag_per_step[i] = chunk_ids_to_kv_ag - num_kv_chunks = chunk_ids_to_kv_ag.numel() - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_ = q[:, i].contiguous() - # [num_kv_chunks, sq//2, b, np, hn] -> [b, num_kv_chunks*sq//2, np, hn] - k_ = ( - torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag) - .movedim(2, 0) - .contiguous() - .view(k.shape[1], -1, *k.shape[-2:]) - ) - v_ = ( - torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag) - .movedim(2, 0) - .contiguous() - .view(v.shape[1], -1, *v.shape[-2:]) - ) - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - q_ = q[i].contiguous() - # [num_kv_chunks, sq//2, b, np, hn] -> [num_kv_chunks*sq//2, b, np, hn] - k_ = torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag).view( - -1, *k.shape[-3:] - ) - v_ = torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag).view( - -1, *v.shape[-3:] + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_ = q.select(seq_dim, i).contiguous() + kv_seq_range_per_step[i], window_size_per_step[i] = ( + get_kv_seq_info_after_all_gather( + local_seq_chunk_ids[i], + cp_size, + max_seqlen_q, + max_seqlen_kv, + window_size, + causal, ) + ) + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i][0], + kv_seq_range_per_step[i][1], + ) + max_seqlen_kv_ = seq_end_idx - seq_start_idx + cu_seqlens_kv_per_step[i] = _get_full_cu_seqlens( + k.shape[1], max_seqlen_kv_, k.device + ) + k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn] + k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] if use_fused_attention: out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd( is_training, max_seqlen_q, - max_seqlen_kv * num_kv_chunks, + max_seqlen_kv_, cu_seqlens_q, - cu_seqlens_kv * num_kv_chunks, + cu_seqlens_kv_per_step[i], q_, k_, v_, @@ -3018,8 +3139,8 @@ def forward( attn_bias_type=attn_bias_type, attn_bias=attn_bias, cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded * num_kv_chunks, - window_size=window_size, + cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], + window_size=window_size_per_step[i], ) else: q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]] @@ -3029,14 +3150,14 @@ def forward( k_, v_, cu_seqlens_q, - cu_seqlens_kv * num_kv_chunks, + cu_seqlens_kv_per_step[i], max_seqlen_q, - max_seqlen_kv * num_kv_chunks, + max_seqlen_kv_, dropout_p, softmax_scale, - causal=True, + causal=causal, return_softmax=False, - window_size=window_size, + window_size=window_size_per_step[i], **fa_optional_forward_kwargs, ) ) @@ -3044,9 +3165,9 @@ def forward( if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): if qkv_format == "bshd": - out[:, i - 1].copy_(out_per_step[i - 
1].view_as(out[:, i - 1])) + out[:, i - 1].copy_(out_per_step[i - 1].view(out[:, i - 1].shape)) elif qkv_format == "sbhd": - out[i - 1].copy_(out_per_step[i - 1].view_as(out[i - 1])) + out[i - 1].copy_(out_per_step[i - 1].view(out[i - 1].shape)) torch.cuda.current_stream().wait_stream(cp_stream) @@ -3063,26 +3184,24 @@ def forward( k, v, cu_seqlens_q, - cu_seqlens_kv, cu_seqlens_q_padded, - cu_seqlens_kv_padded, - *chunk_ids_to_kv_ag_per_step, + *cu_seqlens_kv_per_step, *out_per_step, *softmax_lse_per_step, *rng_states, ) + ctx.kv_seq_range_per_step = kv_seq_range_per_step + ctx.window_size_per_step = window_size_per_step ctx.cp_group = cp_group ctx.cp_stream = cp_stream ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_kv = max_seqlen_kv ctx.softmax_scale = softmax_scale ctx.qkv_format = qkv_format - ctx.attn_mask_type = attn_mask_type ctx.attn_bias_type = attn_bias_type + ctx.attn_mask_type = attn_mask_type ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention - ctx.window_size = window_size return out @staticmethod @@ -3090,21 +3209,20 @@ def backward(ctx, dout): cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) - (q, k, v, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ( - ctx.saved_tensors[:7] - ) - chunk_ids_to_kv_ag_per_step = ctx.saved_tensors[7:9] - out_per_step = ctx.saved_tensors[9:11] - softmax_lse_per_step = ctx.saved_tensors[11:13] - rng_states = ctx.saved_tensors[13:15] + (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = ctx.saved_tensors[:5] + cu_seqlens_kv_per_step = ctx.saved_tensors[5:7] + out_per_step = ctx.saved_tensors[7:9] + softmax_lse_per_step = ctx.saved_tensors[9:11] + rng_states = ctx.saved_tensors[11:13] + kv_seq_range_per_step = ctx.kv_seq_range_per_step + window_size_per_step = ctx.window_size_per_step + seq_dim = ctx.qkv_format.index("s") qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format - dout = dout.view_as(q) + dout = dout.view(q.shape) dq = torch.empty_like(q) - dk = torch.zeros( - (2 * cp_size, k.shape[0] // 2, *k.shape[1:]), dtype=k.dtype, device=k.device - ) + dk = torch.zeros((k.shape[0] * cp_size, *k.shape[1:]), dtype=k.dtype, device=k.device) dv = torch.zeros_like(dk) dq_per_step = [None, None] dk_per_step = [None, None] @@ -3115,11 +3233,20 @@ def backward(ctx, dout): # synchronize dkv update across steps dkv_update_done = torch.cuda.Event() + # [s, b, np, hn] -> [cp, s, b, np, hn] k_ag, _ = gather_along_first_dim(k, ctx.cp_group) v_ag, _ = gather_along_first_dim(v, ctx.cp_group) - ctx.cp_stream.wait_stream(torch.cuda.current_stream()) + + # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True) + k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) + v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] + k_ag = k_ag.view(-1, *k.shape[1:]) + v_ag = v_ag.view(-1, *v.shape[1:]) + ctx.cp_stream.wait_stream(torch.cuda.current_stream()) local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] @@ -3132,66 +3259,46 @@ def backward(ctx, dout): for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - chunk_ids_to_kv_ag = chunk_ids_to_kv_ag_per_step[i] - num_kv_chunks = 
chunk_ids_to_kv_ag.numel() + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_ = q.select(seq_dim, i).contiguous() + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i][0], + kv_seq_range_per_step[i][1], + ) + max_seqlen_kv = seq_end_idx - seq_start_idx + k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # [cp*s, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn] + k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] out_ = out_per_step[i] - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_ = q[:, i].contiguous() - # [num_kv_chunks, sq//2, b, np, hn] -> [b, num_kv_chunks*sq//2, np, hn] - k_ = ( - torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag) - .movedim(2, 0) - .contiguous() - .view(k.shape[1], -1, *k.shape[-2:]) - ) - v_ = ( - torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag) - .movedim(2, 0) - .contiguous() - .view(v.shape[1], -1, *v.shape[-2:]) - ) - dout_ = dout[:, i].contiguous().view_as(out_) - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - q_ = q[i].contiguous() - # [num_kv_chunks, sq//2, b, np, hn] -> [num_kv_chunks*sq//2, b, np, hn] - k_ = torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag).view( - -1, *k.shape[-3:] - ) - v_ = torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag).view( - -1, *v.shape[-3:] - ) - dout_ = dout[i].contiguous().view_as(out_) + dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) if ctx.use_fused_attention: - dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ - torch.empty_like(x) for x in [q_, k_, v_] - ] aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]] dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd( ctx.max_seqlen_q, - ctx.max_seqlen_kv * num_kv_chunks, + max_seqlen_kv, cu_seqlens_q, - cu_seqlens_kv * num_kv_chunks, + cu_seqlens_kv_per_step[i], q_, k_, v_, out_, dout_, TE_DType[q.dtype], - TE_DType[k.dtype], + TE_DType[dout.dtype], aux_ctx_tensors, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded * num_kv_chunks, + cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, qkv_layout=qkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, - window_size=ctx.window_size, + window_size=window_size_per_step[i], + deterministic=ctx.deterministic, ) else: + batch_size = k_.shape[0] q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]] dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ torch.empty_like(x) for x in [q_, k_, v_] @@ -3207,65 +3314,64 @@ def backward(ctx, dout): dk_per_step[i], dv_per_step[i], cu_seqlens_q, - cu_seqlens_kv * num_kv_chunks, + cu_seqlens_kv_per_step[i], ctx.max_seqlen_q, - ctx.max_seqlen_kv * num_kv_chunks, + max_seqlen_kv, ctx.dropout_p, ctx.softmax_scale, - True, - window_size=ctx.window_size, + "causal" in ctx.attn_mask_type, + window_size=window_size_per_step[i], rng_state=rng_states[i], **fa_optional_backward_kwargs, ) + # [b*sq//2, np, hn] -> [b, sq//2, np, hn] + dq_per_step[i] = dq_per_step[i].view(dq[:, i].shape) + # [b*s_range, np, hn] -> [b, s_range, np, hn] + dk_per_step[i], dv_per_step[i] = [ + x.view(batch_size, -1, *x.shape[-2:]) + for x in [dk_per_step[i], dv_per_step[i]] + ] if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): - chunk_ids_to_kv_ag = chunk_ids_to_kv_ag_per_step[i - 1] - num_kv_chunks = 
chunk_ids_to_kv_ag.numel() if ctx.qkv_format == "bshd": - dq[:, i - 1].copy_(dq_per_step[i - 1].view_as(dq[:, i - 1])) - dk_per_step[i - 1] = ( - dk_per_step[i - 1] - .view(k.shape[1], num_kv_chunks, -1, *k.shape[-2:]) - .movedim(0, 2) - .contiguous() - ) - dv_per_step[i - 1] = ( - dv_per_step[i - 1] - .view(v.shape[1], num_kv_chunks, -1, *v.shape[-2:]) - .movedim(0, 2) - .contiguous() - ) + dq[:, i - 1].copy_(dq_per_step[i - 1]) elif ctx.qkv_format == "sbhd": - dq[i - 1].copy_(dq_per_step[i - 1].view_as(dq[i - 1])) - dk_per_step[i - 1] = dk_per_step[i - 1].view( - num_kv_chunks, -1, *k.shape[-3:] - ) - dv_per_step[i - 1] = dv_per_step[i - 1].view( - num_kv_chunks, -1, *v.shape[-3:] - ) - + dq[i - 1].copy_(dq_per_step[i - 1]) + # [b, s_range, np, hn] or [s_range, b, np, hn] -> [s_range, b, np, hn] + dk_per_step[i - 1], dv_per_step[i - 1] = [ + x.movedim(seq_dim, 0).contiguous() + for x in [dk_per_step[i - 1], dv_per_step[i - 1]] + ] # wait until dkv update of last step is done if i > 1: flash_attn_streams[i - 1].wait_event(dkv_update_done) - dk.index_add_(0, chunk_ids_to_kv_ag, dk_per_step[i - 1]) - dv.index_add_(0, chunk_ids_to_kv_ag, dv_per_step[i - 1]) + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i - 1][0], + kv_seq_range_per_step[i - 1][1], + ) + dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1]) + dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1]) if i < len(local_seq_chunk_ids): flash_attn_streams[i - 1].record_event(dkv_update_done) torch.cuda.current_stream().wait_stream(ctx.cp_stream) + # [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn] + dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) + dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, dk.device, False) + dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) + dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] dk = dk.view(-1, *dk.shape[-3:]) dv = dv.view(-1, *dv.shape[-3:]) dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) - if ctx.qkv_format == "bshd": - dq = dq.view(dq.shape[0], -1, *dq.shape[-2:]) - dk = dk.transpose(0, 1).contiguous() - dv = dv.transpose(0, 1).contiguous() - elif ctx.qkv_format == "sbhd": - dq = dq.view(-1, *dq.shape[-3:]) + dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :]) + dk = dk.movedim(0, seq_dim).contiguous() + dv = dv.movedim(0, seq_dim).contiguous() return ( None, @@ -3287,120 +3393,647 @@ def backward(ctx, dout): None, None, None, - None, - None, ) -def attn_forward_func_with_cp( - is_training, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - dropout_p, - cp_group, - cp_global_ranks, - cp_stream, - cp_comm_type, - softmax_scale=None, - qkv_format="bshd", - attn_mask_type="causal", - attn_bias_type="no_bias", - attn_bias=None, - deterministic=False, - use_fused_attention=False, - window_size=None, - fp8=False, - fp8_meta=None, -) -> torch.Tensor: +@torch.compile +def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn): + """Reorder sequence chunk for A2A communication.""" + if before_attn: + # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn] or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn] + x = x.movedim(0, seq_dim).contiguous() + # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = 
x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :]) + # reorder the sequence chunks + x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a) + else: + # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.movedim(seq_dim, 0).contiguous() + # reorder the sequence chunks + x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a) + # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn] + x = x.view(cp_size, 2, *x.shape[1:]) + return x + + +def flash_attn_a2a_communicate( + a2a_inputs: Union[torch.Tensor, List[torch.Tensor]], + chunk_ids_for_a2a: torch.Tensor, + seq_dim: int, + cp_size: int, + cp_group: dist_group_type, + cp_stream: torch.cuda.Stream, + before_attn: bool, +) -> Union[torch.Tensor, List[torch.Tensor]]: + """A2A communication for context parallelism.""" + a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs + a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs) + if before_attn: + for i in range(len(a2a_inputs) + 2): + if 0 < i < len(a2a_inputs) + 1: + a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1]) + a2a_reqs[i - 1] = torch.distributed.all_to_all_single( + a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True + ) + if i > 1: + with torch.cuda.stream(cp_stream): + a2a_reqs[i - 2].wait() + x = a2a_outputs[i - 2] + # reorder the sequence chunks + x = reorder_seq_chunks_for_a2a( + x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn + ) + # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] + a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) + if i < len(a2a_inputs): + x = a2a_inputs[i] + # [b, s, np, hn] -> [b, s, cp, np//cp, hn] or [s, b, np, hn] -> [s, b, cp, np//cp, hn] + x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1]) + # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] + a2a_inputs[i] = x.movedim(-3, 0).contiguous() + else: + for i in range(len(a2a_inputs) + 2): + if 0 < i < len(a2a_inputs) + 1: + a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1]) + a2a_reqs[i - 1] = torch.distributed.all_to_all_single( + a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True + ) + if i < len(a2a_inputs): + x = a2a_inputs[i] + # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) + # reorder the sequence chunks + a2a_inputs[i] = reorder_seq_chunks_for_a2a( + x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn + ) + if i > 1: + with torch.cuda.stream(cp_stream): + a2a_reqs[i - 2].wait() + x = a2a_outputs[i - 2] + # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] + x = x.movedim(0, -3).movedim(0, seq_dim).contiguous() + # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] + a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) + torch.cuda.current_stream().wait_stream(cp_stream) + return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs + + +class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): """ - Attention implementation with context parallelism. 
+ Attention implementation with context parallelism. Like Ulysses, applying A2A to QKVO. + Refer the paper `DeepSpeed Ulysses `_. """ - assert qkv_format in [ - "bshd", - "sbhd", - "thd", - ], f"QKV format of {qkv_format} is not supported with context parallelism!" - assert ( - qkv_format != "sbhd" or use_fused_attention - ), "FlashAttention does not support sbhd format!" - assert ( - qkv_format != "thd" - or not use_fused_attention - or attn_mask_type in ["padding", "padding_causal"] - ), ( - f"Context parallelism is not supported for {attn_mask_type} mask type and " - f"{qkv_format} format with {'FusedAttention' if use_fused_attention else 'FlashAttention'}!" - ) - assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), ( - """Attention bias is only supported with FusedAttention and "causal" """ - """or "no_mask" mask types!""" - ) - assert ( - cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None - ), "cu_seqlens_q_padded and cu_seqlens_kv_padded cannot be None with context parallelism!" - - sliding_window_attn = ( - window_size is not None and window_size != (-1, 0) and window_size != (-1, -1) - ) - - if sliding_window_attn or cp_comm_type == "all_gather": - out = AttnFuncWithCPAndKVAllGather.apply( - is_training, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - dropout_p, - cp_group, - cp_stream, - softmax_scale, - qkv_format, - attn_mask_type, - attn_bias_type, - attn_bias, - deterministic, - use_fused_attention, - window_size, + @staticmethod + def forward( + ctx, + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + dropout_p, + softmax_scale, + qkv_format, + attn_mask_type, + attn_bias_type, + attn_bias, + deterministic, + use_fused_attention, + window_size, + fp8, + fp8_meta, + cp_group, + cp_stream, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + cp_size = get_distributed_world_size(cp_group) + + causal = "causal" in attn_mask_type + padding = "padding" in attn_mask_type + assert not padding, f"{attn_mask_type} mask type is not supported!" + assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" + assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" + assert ( + window_size == (-1, 0) + or window_size == (-1, -1) + or use_fused_attention + or _flash_attn_2_3_plus + ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" + fa_optional_forward_kwargs = {} + if _flash_attn_2_3_plus: + fa_optional_forward_kwargs["window_size"] = window_size + if _flash_attn_2_4_plus: + fa_optional_forward_kwargs["alibi_slopes"] = None + if _flash_attn_2_5_7_plus: + fa_optional_forward_kwargs["block_table"] = None + + assert ( + q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0 + ), "The number of attention heads needs to be divisible by CP size!" + + assert qkv_format != "thd", f"{qkv_format} format is not supported!" + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + + batch_dim = qkv_format.index("b") + seq_dim = qkv_format.index("s") + assert ( + q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 + ), "Sequence length per GPU needs to be divisible by 2!" 
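# --- Illustrative sketch (not part of the patch) ------------------------------
# The all-gather and A2A CP paths both rely on the load-balanced chunk layout:
# with 2*cp_size sequence chunks, rank r holds chunks (r, 2*cp_size - 1 - r)
# (see local_seq_chunk_ids above), so gathered/exchanged tensors have to be
# permuted between "rank order" and "sequence order". A minimal standalone
# sketch of the permutation that get_seq_chunk_ids_for_reordering (added
# earlier in this patch) computes, assuming cp_size = 4:
import torch

def _chunk_ids(cp_size, to_contiguous):
    ids = torch.empty(2 * cp_size, dtype=torch.int32)
    for rank in range(cp_size):
        if to_contiguous:
            ids[rank] = 2 * rank
            ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1
        else:
            ids[2 * rank] = rank
            ids[2 * rank + 1] = 2 * cp_size - rank - 1
    return ids

to_seq_order = _chunk_ids(4, True)    # tensor([0, 2, 4, 6, 7, 5, 3, 1], dtype=torch.int32)
to_rank_order = _chunk_ids(4, False)  # tensor([0, 7, 1, 6, 2, 5, 3, 4], dtype=torch.int32)
# The two permutations are inverses of each other, so reordering after the
# gather/all-to-all and reordering back later round-trips cleanly:
assert torch.equal(
    torch.index_select(to_rank_order, 0, to_seq_order), torch.arange(8, dtype=torch.int32)
)
# ------------------------------------------------------------------------------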
+ + if fp8: + if use_fused_attention: + fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + fused_attn_qkv_dtype = fp8_dtype_forward + fused_attn_backend = FusedAttnBackend["FP8"] + if fp8_meta["recipe"].fp8_mha: + assert ( + isinstance(q, Float8Tensor) + and isinstance(k, Float8Tensor) + and isinstance(v, Float8Tensor) + ), "q/k/v must be Float8Tensors for FP8 MHA!" + fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv + q_fp8, k_fp8, v_fp8 = q, k, v + q, k, v = q_fp8._data, k_fp8._data, v_fp8._data + elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q_f16, k_f16, v_f16 = q, k, v + q, k, v = [ + cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + for x in [q_f16, k_f16, v_f16] + ] + fp8_meta_kwargs = {} + fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv + fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV + fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv + fp8_meta_kwargs["d_scale_s_offset"] = META_S + fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale + fp8_meta_kwargs["q_scale_s_offset"] = META_S + fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale + fp8_meta_kwargs["q_scale_o_offset"] = META_O + fp8_meta_kwargs["amax_s"] = fp8_meta["scaling_fwd"].amax_history + fp8_meta_kwargs["amax_s_offset"] = META_S + fp8_meta_kwargs["amax_o"] = fp8_meta["scaling_fwd"].amax_history + fp8_meta_kwargs["amax_o_offset"] = META_O + else: + assert False, "FP8 is only supported with Fused Attention!" + else: + if use_fused_attention: + fp8_meta_kwargs = {} + fused_attn_qkv_dtype = TE_DType[q.dtype] + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, True) + q, k, v = flash_attn_a2a_communicate( + [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True ) - elif cp_comm_type == "p2p": - out = AttnFuncWithCPAndKVP2P.apply( - is_training, - q, - k, - v, + + if fp8 and not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q_f16, k_f16, v_f16 = q, k, v + q, k, v = [ + cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + for x in [q_f16, k_f16, v_f16] + ] + + batch_size = q.shape[batch_dim] + if use_fused_attention: + out, aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + q, + k, + v, + fused_attn_qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + window_size=window_size, + **fp8_meta_kwargs, + ) + else: + # [b, cp*s, np//cp, hn] -> [b*cp*s, np//cp, hn] + q, k, v = [x.view(-1, *x.shape[-2:]) for x in [q, k, v]] + ( + _, + _, + _, + _, + out, + softmax_lse, + _, + rng_state, + ) = _flash_attn_forward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + dropout_p, + softmax_scale, + causal=causal, + return_softmax=False, + **fa_optional_forward_kwargs, + ) + aux_ctx_tensors = [softmax_lse, rng_state] + # [b*cp*s, np//cp, hn] -> [b, cp*s, np//cp, hn] + out = out.view(batch_size, -1, *out.shape[-2:]) + + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, False) + out = flash_attn_a2a_communicate( + out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False + ) + + if use_fused_attention: + if qkv_format == 
"bshd": + # [b*s, np, hn] -> [b, s, np, hn] + out = out.view(batch_size, -1, *out.shape[-2:]) + elif qkv_format == "sbhd": + # [s*b, np, hn] -> [s, b, np, hn] + out = out.view(-1, batch_size, *out.shape[-2:]) + + if fp8: + if fp8_meta["recipe"].fp8_mha: + out_fp8 = Float8Tensor( + data=out, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_O, + fp8_dtype=fp8_dtype_forward, + dtype=q_fp8.dtype, + ) + out = out_fp8._data + out_ret = out_fp8 + else: + out_f16 = cast_from_fp8( + out, + fp8_meta["scaling_fwd"], + META_O, + fp8_dtype_forward, + TE_DType[q_f16.dtype], + ) + out_ret = out_f16 + else: + out_ret = out + + if fp8: + if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q_save, k_save, v_save, out_save = q, k, v, out + elif fp8_meta["recipe"].fp8_mha: + q_fp8, k_fp8, v_fp8 = [ + Float8Tensor( + data=x, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_QKV, + fp8_dtype=fp8_dtype_forward, + dtype=out_fp8.dtype, + ) + for x in [q, k, v] + ] + q_save, k_save, v_save, out_save = q_fp8, k_fp8, v_fp8, out_fp8 + else: + q_save, k_save, v_save, out_save = q_f16, k_f16, v_f16, out_f16 + else: + q_save, k_save, v_save, out_save = q, k, v, out + + if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone() + fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone() + else: + fp8_fwd_scales, fp8_fwd_scale_invs = None, None + + ctx.save_for_backward( + q_save, + k_save, + v_save, + out_save, cu_seqlens_q, cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, - dropout_p, - cp_group, - cp_global_ranks, - cp_stream, - softmax_scale, - qkv_format, - attn_mask_type, - attn_bias_type, - attn_bias, - deterministic, - use_fused_attention, - fp8, - fp8_meta, + fp8_fwd_scales, + fp8_fwd_scale_invs, + *aux_ctx_tensors, ) + ctx.batch_size = batch_size + ctx.cp_group = cp_group + ctx.cp_stream = cp_stream + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_kv = max_seqlen_kv + ctx.softmax_scale = softmax_scale + ctx.qkv_format = qkv_format + ctx.attn_mask_type = attn_mask_type + ctx.attn_bias_type = attn_bias_type + ctx.deterministic = deterministic + ctx.window_size = window_size + ctx.use_fused_attention = use_fused_attention + ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.fp8_meta = fp8_meta + return out_ret + + @staticmethod + def backward(ctx, dout): + cp_size = get_distributed_world_size(ctx.cp_group) + + q, k, v, out = ctx.saved_tensors[:4] + cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = ctx.saved_tensors[ + 4:8 + ] + fp8_fwd_scales, fp8_fwd_scale_invs = ctx.saved_tensors[8:10] + aux_ctx_tensors = ctx.saved_tensors[10:] + + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + causal = "causal" in ctx.attn_mask_type + seq_dim = ctx.qkv_format.index("s") + + if ctx.fp8: + if ctx.use_fused_attention: + fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) + fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) + fused_attn_qkv_dtype = fp8_dtype_forward + fused_attn_dqkv_dtype = fp8_dtype_backward + fused_attn_backend = FusedAttnBackend["FP8"] + if ctx.fp8_meta["recipe"].fp8_mha: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" 
+ ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv + dout_fp8 = dout + dout = dout_fp8._data + else: + dout_f16 = dout + dout = cast_to_fp8( + dout_f16, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward + ) + fp8_meta_kwargs = {} + fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV] + fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S] + fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O] + fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] + fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP] + fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S] + fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP] + fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV] + fp8_meta_kwargs["amax_dp"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP] + fp8_meta_kwargs["amax_dqkv"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][ + META_DQKV + ] + else: + assert False, "FP8 is only supported with Fused Attention!" + else: + if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" + q, k, v, out, dout = [x.from_float8(x.dtype) for x in [q, k, v, out, dout]] + if ctx.use_fused_attention: + fp8_meta_kwargs = {} + fused_attn_qkv_dtype = TE_DType[q.dtype] + fused_attn_dqkv_dtype = TE_DType[dout.dtype] + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + + if not ctx.use_fused_attention: + out = out.view(ctx.batch_size, -1, *out.shape[-2:]) + dout = dout.view(*out.shape) + + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, True) + out, dout = flash_attn_a2a_communicate( + [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True + ) + + fa_optional_backward_kwargs = {} + if _flash_attn_2_3_plus: + fa_optional_backward_kwargs["window_size"] = ctx.window_size + if _flash_attn_2_4_plus: + fa_optional_backward_kwargs["alibi_slopes"] = None + if _flash_attn_2_4_1_plus: + fa_optional_backward_kwargs["deterministic"] = ctx.deterministic + + if ctx.use_fused_attention: + dq, dk, dv, _ = fused_attn_bwd( + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + q, + k, + v, + out, + dout, + fused_attn_qkv_dtype, + fused_attn_dqkv_dtype, + aux_ctx_tensors, + fused_attn_backend, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + attn_scale=ctx.softmax_scale, + dropout=ctx.dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=ctx.attn_mask_type, + attn_bias_type=ctx.attn_bias_type, + window_size=ctx.window_size, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, + ) + else: + softmax_lse, rng_state = aux_ctx_tensors + out, dout = [x.view(-1, *x.shape[-2:]) for x in [out, dout]] + dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_kv, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + ctx.dropout_p, + ctx.softmax_scale, + causal, + rng_state=rng_state, + **fa_optional_backward_kwargs, + ) + dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] + + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, False) + dq, dk, dv = flash_attn_a2a_communicate( + [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False + ) + + if ctx.qkv_format == "bshd": + dq, dk, dv = [x.view(ctx.batch_size, -1, 
*x.shape[-2:]) for x in [dq, dk, dv]] + elif ctx.qkv_format == "sbhd": + dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] + + if ctx.fp8: + if ctx.fp8_meta["recipe"].fp8_mha: + dq, dk, dv = [ + Float8Tensor( + data=x, + fp8_meta=ctx.fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=META_DQKV, + fp8_dtype=fp8_dtype_backward, + dtype=dout_fp8.dtype, + ) + for x in [dq, dk, dv] + ] + else: + dq, dk, dv = [ + cast_from_fp8( + x, + ctx.fp8_meta["scaling_bwd"], + META_DQKV, + fp8_dtype_backward, + TE_DType[dout_f16.dtype], + ) + for x in [dq, dk, dv] + ] + + return ( + None, + dq, + dk, + dv, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def attn_forward_func_with_cp( + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + dropout_p, + cp_group, + cp_global_ranks, + cp_stream, + cp_comm_type, + softmax_scale=None, + qkv_format="bshd", + attn_mask_type="causal", + attn_bias_type="no_bias", + attn_bias=None, + deterministic=False, + use_fused_attention=False, + window_size=None, + fp8=False, + fp8_meta=None, +) -> torch.Tensor: + """ + Attention implementation with context parallelism. + """ + + assert qkv_format in [ + "bshd", + "sbhd", + "thd", + ], f"QKV format of {qkv_format} is not supported with context parallelism!" + assert ( + qkv_format != "sbhd" or use_fused_attention + ), "FlashAttention does not support sbhd format!" + assert ( + qkv_format != "thd" + or not use_fused_attention + or attn_mask_type in ["padding", "padding_causal"] + ), ( + f"Context parallelism is not supported for {attn_mask_type} mask type and " + f"{qkv_format} format with {'FusedAttention' if use_fused_attention else 'FlashAttention'}!" + ) + assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), ( + """Attention bias is only supported with FusedAttention and "causal" """ + """or "no_mask" mask types!""" + ) + assert ( + cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None + ), "cu_seqlens_q_padded and cu_seqlens_kv_padded cannot be None with context parallelism!" + + sliding_window_attn = ( + window_size is not None and window_size != (-1, 0) and window_size != (-1, -1) + ) + assert ( + not sliding_window_attn + or cp_comm_type == "a2a" + or (cp_comm_type == "all_gather" and not use_fused_attention) + ), "The context parallel running configs cannot support sliding window attetnion!" 
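# --- Illustrative sketch (not part of the patch) ------------------------------
# The assert above gates sliding window attention (SWA) by cp_comm_type. The
# helper below is made up for illustration; only the boolean condition mirrors
# the patch. Printing the combinations makes the supported configs explicit:
def _swa_allowed(cp_comm_type, use_fused_attention, window_size):
    sliding_window_attn = (
        window_size is not None and window_size != (-1, 0) and window_size != (-1, -1)
    )
    return (
        not sliding_window_attn
        or cp_comm_type == "a2a"
        or (cp_comm_type == "all_gather" and not use_fused_attention)
    )

for comm in ("p2p", "all_gather", "a2a"):
    for fused in (False, True):
        print(comm, fused, _swa_allowed(comm, fused, window_size=(128, 0)))
# With a real window such as (128, 0), the gate rejects "p2p" for both backends,
# accepts "all_gather" only with FlashAttention (use_fused_attention=False),
# and accepts "a2a" for both.
# ------------------------------------------------------------------------------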
+ + args = [ + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + dropout_p, + softmax_scale, + qkv_format, + attn_mask_type, + attn_bias_type, + attn_bias, + deterministic, + use_fused_attention, + ] + + if cp_comm_type == "p2p": + args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream] + out = AttnFuncWithCPAndKVP2P.apply(*args) + elif cp_comm_type == "all_gather": + args.pop(5) + args.pop(8) + args += [window_size, cp_group, cp_stream] + out = AttnFuncWithCPAndKVAllGather.apply(*args) + elif cp_comm_type == "a2a": + args += [window_size, fp8, fp8_meta, cp_group, cp_stream] + out = AttnFuncWithCPAndQKVOA2A.apply(*args) else: raise ValueError(f"Unsupported communication type: {cp_comm_type}!") @@ -4047,9 +4680,8 @@ def run_iteratively(q, k, v): stride = q.stride() check_strides_qkv = all(stride == x.stride() for x in [q, k, v]) - stride = k.stride() - check_strides_kv = torch.equal( - torch.Tensor(stride[:-1]) / k.shape[-1], torch.Tensor(v.stride()[:-1]) / v.shape[-1] + check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple( + sv / v.shape[-1] for sv in v.stride()[:-1] ) shape = q.shape @@ -4213,14 +4845,15 @@ def forward( cp_global_ranks: List[int] = None, cp_stream: torch.cuda.Stream = None, cp_comm_type: str = "p2p", + fp8: bool = False, + fp8_meta: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: """flash-attn fprop""" - assert ( - query_layer.dtype in [torch.float16, torch.bfloat16] - and key_layer.dtype in [torch.float16, torch.bfloat16] - and value_layer.dtype in [torch.float16, torch.bfloat16] - ), "FlashAttention currently only supports FP16 and BF16." + assert all( + x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor) + for x in [query_layer, key_layer, value_layer] + ), "FlashAttention only supports FP16 and BF16 data types, or Float8Tensors." assert ( query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda ), "FlashAttention currently only supports CUDA tensors." 
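Illustrative sketch (not part of the patch): in the cp_comm_type dispatch above, the all_gather branch pops two entries from args because AttnFuncWithCPAndKVAllGather.forward (earlier in this patch) no longer accepts cu_seqlens_kv or cu_seqlens_kv_padded; it builds cu_seqlens_kv_per_step internally for each KV range instead. The second pop uses index 8 because the first pop shifts the remaining entries left. The name list below exists only to make that visible.

# Sketch of the two pops in the all_gather branch (names list is illustrative only).
names = [
    "is_training", "q", "k", "v", "cu_seqlens_q", "cu_seqlens_kv",
    "max_seqlen_q", "max_seqlen_kv", "cu_seqlens_q_padded", "cu_seqlens_kv_padded",
    "dropout_p", "softmax_scale", "qkv_format", "attn_mask_type",
    "attn_bias_type", "attn_bias", "deterministic", "use_fused_attention",
]
names.pop(5)  # removes "cu_seqlens_kv"
names.pop(8)  # original index 9, shifted left by the first pop: removes "cu_seqlens_kv_padded"
assert "cu_seqlens_kv" not in names and "cu_seqlens_kv_padded" not in names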
@@ -4233,24 +4866,36 @@ def forward( qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "sbhd": - # For now just 128, will make it more general in the future - if ( - query_layer.shape[-1] == 128 - and query_layer.shape[0] * query_layer.shape[1] >= 512 - and qkv_layout == "sbh3d" - ): - query_layer, key_layer, value_layer = _PrepareQKVForFA.apply( - query_layer, key_layer, value_layer - ) - else: + if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): + if qkv_format == "sbhd": + # For now just 128, will make it more general in the future + if ( + query_layer.shape[-1] == 128 + and query_layer.shape[0] * query_layer.shape[1] >= 512 + and qkv_layout == "sbh3d" + ): + query_layer, key_layer, value_layer = _PrepareQKVForFA.apply( + query_layer, key_layer, value_layer + ) + else: + query_layer, key_layer, value_layer = [ + x.transpose(0, 1).contiguous() + for x in (query_layer, key_layer, value_layer) + ] + elif qkv_format in ["bshd", "thd"]: query_layer, key_layer, value_layer = [ - x.transpose(0, 1).contiguous() for x in (query_layer, key_layer, value_layer) + x.contiguous() for x in (query_layer, key_layer, value_layer) + ] + else: + if qkv_format == "sbhd": + query_layer._data, key_layer._data, value_layer._data = [ + x.transpose(0, 1).contiguous() + for x in (query_layer._data, key_layer._data, value_layer._data) + ] + elif qkv_format in ["bshd", "thd"]: + query_layer._data, key_layer._data, value_layer._data = [ + x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data) ] - elif qkv_format in ["bshd", "thd"]: - query_layer, key_layer, value_layer = [ - x.contiguous() for x in (query_layer, key_layer, value_layer) - ] batch_size = query_layer.shape[0] @@ -4258,16 +4903,15 @@ def forward( max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1] max_seqlen_q *= cp_size max_seqlen_kv *= cp_size - if not context_parallel: + + if "padding" in attn_mask_type: + assert not context_parallel, "Padding mask not supported with context parallelism!" # [b * s, h, d] query_layer, key_layer, value_layer = [ - x.view(x.shape[0] * x.shape[1], *x.shape[2:]) + x.reshape(x.shape[0] * x.shape[1], *x.shape[2:]) for x in [query_layer, key_layer, value_layer] ] - if "padding" in attn_mask_type: - assert not context_parallel, "Padding mask not supported with context parallelism!" - if self.attention_type == "self": assert ( max_seqlen_q == max_seqlen_kv @@ -4320,7 +4964,9 @@ def forward( seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] max_seqlen_kv = seqlens_kv.max().item() - if context_parallel: + if context_parallel and all( + not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer] + ): assert ( alibi_slopes is None ), "Alibi slope bias addition is not supported with context parallelism." 
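Illustrative sketch (not part of the patch): the hunks above keep the layouts flash-attn expects while adding the Float8Tensor branches. sbhd inputs are transposed to bshd (operating on ._data when q/k/v are Float8Tensors), and with a padding mask the tensors are flattened to [total_tokens, h, d] for the varlen kernels, paired with cumulative sequence lengths. A minimal shape-only walk-through on dummy tensors; all sizes are made up.

import torch

s, b, h, d = 8, 2, 4, 16
q = torch.randn(s, b, h, d)           # "sbhd" layout
q = q.transpose(0, 1).contiguous()    # -> "bshd": [b, s, h, d]
q_varlen = q.reshape(b * s, h, d)     # -> [b*s, h, d] for the varlen path
# Full (unpadded) cumulative sequence lengths, one entry per batch boundary:
cu_seqlens_q = torch.arange(0, (b + 1) * s, step=s, dtype=torch.int32, device=q.device)
assert q_varlen.shape == (b * s, h, d)
assert cu_seqlens_q.tolist() == [0, 8, 16]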
@@ -4367,34 +5013,94 @@ def forward( fa_optional_forward_kwargs["deterministic"] = self.deterministic if _flash_attn_2_5_7_plus: fa_optional_forward_kwargs["block_table"] = None - output = flash_attn_forward_func( - query_layer, - key_layer, - value_layer, - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - self.attention_dropout if self.training else 0.0, - softmax_scale=self.softmax_scale, - causal="causal" in attn_mask_type, - **fa_optional_forward_kwargs, - ) + fa_optional_forward_args_thd = [] + if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type: + func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3 + else: + func = ( + flash_attn_varlen_func + if not _use_flash_attn_3 + else flash_attn_varlen_func_v3 + ) + fa_optional_forward_args_thd.append(cu_seqlens_q) + fa_optional_forward_args_thd.append(cu_seqlens_kv) + fa_optional_forward_args_thd.append(max_seqlen_q) + fa_optional_forward_args_thd.append(max_seqlen_kv) + if _use_flash_attn_3: + if fp8: + fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + activation_dtype = query_layer.dtype + torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True) + if fp8_meta["recipe"].fp8_mha: + assert all( + isinstance(x, Float8Tensor) + for x in [query_layer, key_layer, value_layer] + ), "q/k/v must be Float8Tensors for FP8 MHA." + fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv + query_layer, key_layer, value_layer = ( + x.to(activation_dtype).to(torch_dtype) + for x in [query_layer, key_layer, value_layer] + ) + else: + query_layer, key_layer, value_layer = ( + x.to(torch_dtype) for x in [query_layer, key_layer, value_layer] + ) + output, _ = func( + query_layer, + key_layer, + value_layer, + *fa_optional_forward_args_thd, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + deterministic=self.deterministic, + ) + if fp8 and fp8_meta["recipe"].fp8_mha: + output = cast_to_fp8( + output, + fp8_meta["scaling_fwd"], + META_O, + fp8_dtype_forward, + ) + output = Float8Tensor( + data=output, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_O, + fp8_dtype=fp8_dtype_forward, + dtype=activation_dtype, + ) + else: + output = func( + query_layer, + key_layer, + value_layer, + *fa_optional_forward_args_thd, + self.attention_dropout if self.training else 0.0, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + **fa_optional_forward_kwargs, + ) if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type: output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) if qkv_format == "sbhd": # (bs)hd -> bs(hd) -> sb(hd) - output = ( - output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1).contiguous() - ) + if fp8 and fp8_meta["recipe"].fp8_mha: + output.reshape(batch_size * max_seqlen_q // cp_size, -1).transpose_2d() + output = output.reshape(batch_size, max_seqlen_q // cp_size, -1) + else: + output = ( + output.view(batch_size, max_seqlen_q // cp_size, -1) + .transpose(0, 1) + .contiguous() + ) elif qkv_format == "bshd": # (bs)hd -> bs(hd) - output = output.view(batch_size, max_seqlen_q // cp_size, -1).contiguous() + output = output.reshape(batch_size, max_seqlen_q // cp_size, -1) elif qkv_format == "thd": # thd -> t(hd) - output = output.view(output.shape[0], -1).contiguous() + output = output.reshape(output.shape[0], -1) return output @@ -4455,19 +5161,20 @@ def forward( fp8_meta, deterministic, ): + is_input_fp8 = False + is_output_fp8 = 
fp8_meta["recipe"].fp8_mha if fp8: - if fp8_meta["recipe"].fp8_mha: - assert isinstance(qkv, Float8Tensor), "qkv must be Float8Tensors for FP8 MHA." + is_input_fp8 = isinstance(qkv, Float8Tensor) + if is_input_fp8: fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv fused_attention_backend = FusedAttnBackend["FP8"] fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) # 1: qkv packed, 2: kv packed, 3: qkv separate qkv_group = len(qkv_layout.split("_")) - assert qkv_group == 1, ( - "qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found" - f" {qkv_layout}." - ) - if fp8_meta["recipe"].fp8_mha: + assert ( + qkv_group == 1 + ), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found {qkv_layout}." + if is_input_fp8: qkv_fp8 = qkv._data else: qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) @@ -4483,12 +5190,18 @@ def forward( fused_attention_backend, attn_bias, cu_seqlens_padded, - fp8_meta["scaling_fwd"].scale_inv[META_QKV], - fp8_meta["scaling_fwd"].scale_inv[META_S], - fp8_meta["scaling_fwd"].scale[META_S], - fp8_meta["scaling_fwd"].scale[META_O], - fp8_meta["scaling_fwd"].amax_history[0][META_S], - fp8_meta["scaling_fwd"].amax_history[0][META_O], + fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv + META_QKV, # d_scale_qkv_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_s + META_S, # d_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_s + META_S, # q_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_o + META_O, # q_scale_o_offset + fp8_meta["scaling_fwd"].amax_history, # amax_s + META_S, # amax_s_offset + fp8_meta["scaling_fwd"].amax_history, # amax_o + META_O, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -4498,7 +5211,7 @@ def forward( window_size, rng_gen, ) - if fp8_meta["recipe"].fp8_mha: + if is_output_fp8: out_ret = Float8Tensor( data=out_fp8, fp8_meta=fp8_meta, @@ -4516,22 +5229,24 @@ def forward( qkv_dtype, ).view(out_fp8.shape) out_save = out_ret - if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv = cast_from_fp8( - qkv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[qkv.dtype], - ).view(qkv.shape) - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) + if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + if is_input_fp8: + qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) + qkv = cast_from_fp8( + qkv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[qkv.dtype], + ).view(qkv.shape) + if is_output_fp8: + out_save = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), + fp8_meta["scaling_fwd"], + META_O, + fp8_dtype_forward, + qkv_dtype, + ).view(out_fp8.shape) fp8_tensors = ( qkv_fp8, out_fp8, @@ -4548,12 +5263,18 @@ def forward( fused_attention_backend, attn_bias, cu_seqlens_padded, - None, - None, - None, - None, - None, - None, + None, # d_scale_qkv + 0, # d_scale_qkv_offset + None, # d_scale_s + 0, # d_scale_s_offset + None, # q_scale_s + 0, # q_scale_s_offset + None, # q_scale_o + 0, # q_scale_o_offset + None, # amax_s + 0, # amax_s_offset + None, # amax_o + 0, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -4567,6 +5288,8 @@ def forward( out_save = out_ret ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.is_input_fp8 = 
is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None) ctx.save_for_backward( *qkvo_tensors, cu_seqlens, cu_seqlens_padded, *fp8_tensors, *aux_ctx_tensors @@ -4591,7 +5314,7 @@ def forward( @staticmethod def backward(ctx, d_out): - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: assert isinstance( d_out, Float8Tensor ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." @@ -4648,7 +5371,7 @@ def backward(ctx, d_out): fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: d_out_fp8 = d_out ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv else: @@ -4688,7 +5411,7 @@ def backward(ctx, d_out): ctx.window_size, ctx.deterministic, ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_input_fp8: dqkv = Float8Tensor( data=dqkv_fp8, fp8_meta=ctx.fp8_meta, @@ -4826,22 +5549,23 @@ def forward( fp8_meta, deterministic, ): + is_input_fp8 = False + is_output_fp8 = fp8_meta["recipe"].fp8_mha if fp8: - if fp8_meta["recipe"].fp8_mha: - assert isinstance(q, Float8Tensor) and isinstance( - kv, Float8Tensor - ), "q/kv must be Float8Tensors for FP8 MHA." + assert isinstance(kv, q.__class__), "q and kv must have the same type." + is_input_fp8 = isinstance(q, Float8Tensor) + if is_input_fp8: fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv fused_attention_backend = FusedAttnBackend["FP8"] fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if fp8_meta["recipe"].fp8_mha: + if is_input_fp8: q_fp8, kv_fp8 = q._data, kv._data else: # 1: qkv packed, 2: kv packed, 3: qkv separate qkv_group = len(qkv_layout.split("_")) assert qkv_group == 2, ( - "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, " - f" but found {qkv_layout}." + "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, " + f"but found {qkv_layout}." 
) q_fp8 = cast_to_fp8(q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward).view( q.shape @@ -4863,12 +5587,18 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fp8_meta["scaling_fwd"].scale_inv[META_QKV], - fp8_meta["scaling_fwd"].scale_inv[META_S], - fp8_meta["scaling_fwd"].scale[META_S], - fp8_meta["scaling_fwd"].scale[META_O], - fp8_meta["scaling_fwd"].amax_history[0][META_S], - fp8_meta["scaling_fwd"].amax_history[0][META_O], + fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv + META_QKV, # d_scale_qkv_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_s + META_S, # d_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_s + META_S, # q_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_o + META_O, # q_scale_o_offset + fp8_meta["scaling_fwd"].amax_history, # amax_s + META_S, # amax_s_offset + fp8_meta["scaling_fwd"].amax_history, # amax_o + META_O, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -4878,7 +5608,7 @@ def forward( window_size, rng_gen, ) - if fp8_meta["recipe"].fp8_mha: + if is_output_fp8: out_ret = Float8Tensor( data=out_fp8, fp8_meta=fp8_meta, @@ -4896,25 +5626,31 @@ def forward( qkv_dtype, ).view(out_fp8.shape) out_save = out_ret - if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q = cast_from_fp8( - q._data, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, TE_DType[q.dtype] - ).view(q.shape) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv = cast_from_fp8( - kv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[kv.dtype], - ).view(kv.shape) - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) + if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + if is_input_fp8: + q = cast_from_fp8( + q._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[q.dtype], + ).view(q.shape) + kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) + kv = cast_from_fp8( + kv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[kv.dtype], + ).view(kv.shape) + if is_output_fp8: + out_save = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), + fp8_meta["scaling_fwd"], + META_O, + fp8_dtype_forward, + qkv_dtype, + ).view(out_fp8.shape) fp8_tensors = ( q_fp8, kv_fp8, @@ -4936,12 +5672,18 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - None, - None, - None, - None, - None, - None, + None, # d_scale_qkv + 0, # d_scale_qkv_offset + None, # d_scale_s + 0, # d_scale_s_offset + None, # q_scale_s + 0, # q_scale_s_offset + None, # q_scale_o + 0, # q_scale_o_offset + None, # amax_s + 0, # amax_s_offset + None, # amax_o + 0, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -4955,6 +5697,8 @@ def forward( fp8_tensors = (None, None, None, None, None) ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None) ctx.save_for_backward( *qkvo_tensors, @@ -4986,7 +5730,7 @@ def forward( @staticmethod def backward(ctx, d_out): - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: assert isinstance( d_out, Float8Tensor ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." 
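The kvpacked forward above follows the same pattern as the qkvpacked one: instead of pre-indexing fp8_meta["scaling_fwd"].scale / .amax_history on the Python side, it now hands the whole tensors to the fused kernel together with integer offsets (META_QKV, META_S, META_O). A hedged sketch of that convention, using a made-up kernel stub and placeholder offset values rather than the real tex enum members:

.. code-block:: python

    import torch

    META_S, META_O = 2, 1  # placeholder values, not the real enum members

    scale = torch.full((8,), 0.5)
    amax_history = torch.zeros(16, 8)

    def fused_kernel_stub(scale, q_scale_s_offset, amax_history, amax_o_offset):
        # Stand-in for what the fused kernel does with each (tensor, offset) pair.
        q_scale_s = scale[q_scale_s_offset]
        amax_o = amax_history[0, amax_o_offset]
        return q_scale_s, amax_o

    q_scale_s, amax_o = fused_kernel_stub(scale, META_S, amax_history, META_O)
    assert q_scale_s.item() == 0.5 and amax_o.item() == 0.0

Passing tensor-plus-offset keeps the Python side free of per-call slicing and lets the kernel read exactly the scaling factors it needs.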
@@ -5047,7 +5791,7 @@ def backward(ctx, d_out): fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: d_out_fp8 = d_out ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv else: @@ -5091,7 +5835,7 @@ def backward(ctx, d_out): ctx.window_size, ctx.deterministic, ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_input_fp8: dq = Float8Tensor( data=dq_fp8, fp8_meta=ctx.fp8_meta, @@ -5257,15 +6001,16 @@ def forward( fp8_meta, deterministic, ): + is_input_fp8 = False + is_output_fp8 = fp8_meta["recipe"].fp8_mha if fp8: fused_attention_backend = FusedAttnBackend["FP8"] fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if fp8_meta["recipe"].fp8_mha: - assert ( - isinstance(q, Float8Tensor) - and isinstance(k, Float8Tensor) - and isinstance(v, Float8Tensor) - ), "q/k/v must be Float8Tensors for FP8 MHA." + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, and v must have the same type." + is_input_fp8 = isinstance(q, Float8Tensor) + if is_input_fp8: fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data else: @@ -5316,12 +6061,18 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fp8_meta["scaling_fwd"].scale_inv[META_QKV], - fp8_meta["scaling_fwd"].scale_inv[META_S], - fp8_meta["scaling_fwd"].scale[META_S], - fp8_meta["scaling_fwd"].scale[META_O], - fp8_meta["scaling_fwd"].amax_history[0][META_S], - fp8_meta["scaling_fwd"].amax_history[0][META_O], + fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv + META_QKV, # d_scale_qkv_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_s + META_S, # d_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_s + META_S, # q_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_o + META_O, # q_scale_o_offset + fp8_meta["scaling_fwd"].amax_history, # amax_s + META_S, # amax_s_offset + fp8_meta["scaling_fwd"].amax_history, # amax_o + META_O, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -5331,7 +6082,7 @@ def forward( window_size, rng_gen, ) - if fp8_meta["recipe"].fp8_mha: + if is_output_fp8: out_ret = Float8Tensor( data=out_fp8, fp8_meta=fp8_meta, @@ -5350,71 +6101,73 @@ def forward( ).view(out_fp8.shape) out_save = out_ret - if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split("_")) - if qkv_group == 1: - dim = qkv_layout.find("3") - qkv = _combine_tensors([q, k, v], dim) - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_no_fp8 = cast_from_fp8( - qkv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[qkv.dtype], - ).view(qkv.shape) - q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1]) - q, k, v = [x.squeeze(dim) for x in [q, k, v]] - if qkv_group == 2: - q = cast_from_fp8( - q._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[q.dtype], - ).view(q.shape) - dim = qkv_layout.split("_")[1].find("2") - kv = _combine_tensors([k, v], dim) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_no_fp8 = cast_from_fp8( - kv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[kv.dtype], - ).view(kv.shape) - k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1]) - k, v = [x.squeeze(dim) for x in [k, v]] - if qkv_group == 3: - q = 
cast_from_fp8( - q._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[q.dtype], - ).view(q.shape) - k = cast_from_fp8( - k._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[k.dtype], - ).view(k.shape) - v = cast_from_fp8( - v._data, + if is_input_fp8: + qkv_group = len(qkv_layout.split("_")) + if qkv_group == 1: + dim = qkv_layout.find("3") + qkv = _combine_tensors([q, k, v], dim) + qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) + qkv_no_fp8 = cast_from_fp8( + qkv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[qkv.dtype], + ).view(qkv.shape) + q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1]) + q, k, v = [x.squeeze(dim) for x in [q, k, v]] + if qkv_group == 2: + q = cast_from_fp8( + q._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[q.dtype], + ).view(q.shape) + dim = qkv_layout.split("_")[1].find("2") + kv = _combine_tensors([k, v], dim) + kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) + kv_no_fp8 = cast_from_fp8( + kv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[kv.dtype], + ).view(kv.shape) + k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1]) + k, v = [x.squeeze(dim) for x in [k, v]] + if qkv_group == 3: + q = cast_from_fp8( + q._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[q.dtype], + ).view(q.shape) + k = cast_from_fp8( + k._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[k.dtype], + ).view(k.shape) + v = cast_from_fp8( + v._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[v.dtype], + ).view(v.shape) + if is_output_fp8: + out_save = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), fp8_meta["scaling_fwd"], - META_QKV, + META_O, fp8_dtype_forward, - TE_DType[v.dtype], - ).view(v.shape) - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) + qkv_dtype, + ).view(out_fp8.shape) fp8_tensors = ( q_fp8, @@ -5439,12 +6192,18 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - None, - None, - None, - None, - None, - None, + None, # d_scale_qkv + 0, # d_scale_qkv_offset + None, # d_scale_s + 0, # d_scale_s_offset + None, # q_scale_s + 0, # q_scale_s_offset + None, # q_scale_o + 0, # q_scale_o_offset + None, # amax_s + 0, # amax_s_offset + None, # amax_o + 0, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -5457,16 +6216,25 @@ def forward( out_save = out_ret fp8_tensors = (None, None, None, None, None, None) + ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + from .cpu_offload import CPUOffloadEnabled if CPUOffloadEnabled: - tensor_list = [q, k, v, out_save, cu_seqlens_q, cu_seqlens_kv] + if ctx.fp8: + tensor_list = fp8_tensors + else: + tensor_list = [q, k, v, out_save] + + tensor_list.extend(aux_ctx_tensors) + qkv_layout = "sbhd_sbhd_sbhd" for tensor in tensor_list: if tensor is not None: tensor.activation_offloading = True - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) ctx.save_for_backward( *qkvo_tensors, @@ -5498,7 +6266,7 @@ def forward( @staticmethod def backward(ctx, d_out): - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: assert isinstance( 
d_out, Float8Tensor ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." @@ -5563,7 +6331,7 @@ def backward(ctx, d_out): fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: d_out_fp8 = d_out ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv else: @@ -5609,7 +6377,7 @@ def backward(ctx, d_out): ctx.deterministic, ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_input_fp8: dq = Float8Tensor( data=dq_fp8, fp8_meta=ctx.fp8_meta, @@ -5898,11 +6666,10 @@ def forward( assert ( fused_attention_backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend ), "No fused attention backend supports this input combination!" - assert ( - (query_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8]) - and (key_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8]) - and (value_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8]) - ), "FusedAttention only supports FP16 and BF16 data types." + assert all( + x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor) + for x in [query_layer, key_layer, value_layer] + ), "FusedAttention only supports FP16 and BF16 data types, or Float8Tensors." assert ( query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda ), "FusedAttention only supports CUDA tensors." @@ -6167,7 +6934,13 @@ class DotProductAttention(TransformerEngineBaseModule): can overlap two flash attention kernels. cp_comm_type : str inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather". + Can be "p2p" or "all_gather" or "a2a". + "p2p": Exchange KV chunks with P2P communications in ring topology. + P2P is async and can be overlapped with attention compute. + "all_gather": All-gather to get full sequence of KV before attention. + The all-gather is not async, and cannot be overlapped. + "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP + group, and gather to get full sequence of QKV. """ def __init__( @@ -6359,7 +7132,13 @@ def set_context_parallel_group( cuda stream for context parallel execution. cp_comm_type : str inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather". + Can be "p2p" or "all_gather" or "a2a". + "p2p": Exchange KV chunks with P2P communications in ring topology. + P2P is async and can be overlapped with attention compute. + "all_gather": All-gather to get full sequence of KV before attention. + The all-gather is not async, and cannot be overlapped. + "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP + group, and gather to get full sequence of QKV. 
""" self.cp_group = cp_group self.cp_global_ranks = cp_global_ranks @@ -6813,7 +7592,7 @@ def forward( fp8=self.fp8, fp8_meta=self.fp8_meta, ) - global _attention_backends + global _attention_backends, _flash_attn_3_plus, _use_flash_attn_3 if ( _attention_backends["attention_params"] is None or attention_params != _attention_backends["attention_params"] @@ -6821,6 +7600,7 @@ def forward( _attention_backends["attention_params"] = attention_params _attention_backends["backend_selection_requires_update"] = True if _attention_backends["backend_selection_requires_update"]: + _use_flash_attn_3 = _flash_attn_3_plus ( use_flash_attention, use_fused_attention, @@ -6829,7 +7609,10 @@ def forward( _, ) = get_attention_backend(attention_params) if use_flash_attention: - self.logger.info("Running with FlashAttention backend") + self.logger.info( + "Running with FlashAttention backend (version %s)", + _flash_attn_version if not _use_flash_attn_3 else _flash_attn_v3_version, + ) elif use_fused_attention: self.logger.info( "Running with FusedAttention backend (sub-backend %s)", @@ -6868,6 +7651,8 @@ def forward( cp_comm_type=self.cp_comm_type, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, + fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, + fp8_meta=self.fp8_meta, ) if use_fused_attention: @@ -7071,7 +7856,7 @@ class MultiheadAttention(torch.nn.Module): bias : bool, default = `True` if set to `False`, the transformer layer will not learn any additive biases. device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. qkv_format: str, default = `sbhd` @@ -7378,7 +8163,13 @@ def set_context_parallel_group( cuda stream for context parallel execution. cp_comm_type : str inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather". + Can be "p2p" or "all_gather" or "a2a". + "p2p": Exchange KV chunks with P2P communications in ring topology. + P2P is async and can be overlapped with attention compute. + "all_gather": All-gather to get full sequence of KV before attention. + The all-gather is not async, and cannot be overlapped. + "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP + group, and gather to get full sequence of QKV. """ # Deep iterate but skip self to avoid infinite recursion. 
for index, child in enumerate(self.modules()): @@ -7534,12 +8325,18 @@ def forward( # Query, Key, and Value # ====================== + fp8_mha = ( + FP8GlobalStateManager.is_fp8_enabled() + and FP8GlobalStateManager.get_fp8_recipe().fp8_mha + ) + if self.attention_type == "self": # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn] if self.input_layernorm: layernorm_qkv_outputs = self.layernorm_qkv( hidden_states, is_first_microbatch=is_first_microbatch, + fp8_output=fp8_mha and rotary_pos_emb is None, ) if self.return_layernorm_output: mixed_x_layer, layernorm_output = layernorm_qkv_outputs @@ -7549,7 +8346,7 @@ def forward( mixed_x_layer = self.qkv( hidden_states, is_first_microbatch=is_first_microbatch, - is_first_module_in_mha=True, # specific to FP8 MHA + fp8_output=fp8_mha and rotary_pos_emb is None, ) num_queries_per_key_value = ( @@ -7610,7 +8407,7 @@ def forward( mixed_kv_layer = self.key_value( encoder_output, is_first_microbatch=is_first_microbatch, - is_first_module_in_mha=True, # specific to FP8 MHA + fp8_output=fp8_mha and rotary_pos_emb is None, ) if self.qkv_weight_interleaved: @@ -7660,6 +8457,7 @@ def forward( layernorm_query_outputs = self.layernorm_query( hidden_states, is_first_microbatch=is_first_microbatch, + fp8_output=fp8_mha and rotary_pos_emb is None, ) if self.return_layernorm_output: query_layer, layernorm_output = layernorm_query_outputs @@ -7669,7 +8467,7 @@ def forward( query_layer = self.query_layer( hidden_states, is_first_microbatch=is_first_microbatch, - is_first_module_in_mha=True, # specific to FP8 MHA + fp8_output=fp8_mha and rotary_pos_emb is None, ) # [sq, b, hp] --> [sq, b, np, hn] diff --git a/transformer_engine/pytorch/cpp_extensions/__init__.py b/transformer_engine/pytorch/cpp_extensions/__init__.py index 61d688f3f4..9f3c1b2424 100644 --- a/transformer_engine/pytorch/cpp_extensions/__init__.py +++ b/transformer_engine/pytorch/cpp_extensions/__init__.py @@ -11,3 +11,4 @@ from .activation import * from .normalization import * from .cast import * +from .padding import * diff --git a/transformer_engine/pytorch/cpp_extensions/_common.py b/transformer_engine/pytorch/cpp_extensions/_common.py new file mode 100644 index 0000000000..b9d7288dfa --- /dev/null +++ b/transformer_engine/pytorch/cpp_extensions/_common.py @@ -0,0 +1,87 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Helper functions for C++ extensions""" +import functools +from typing import Dict, Optional, Tuple, Union + +import torch + +import transformer_engine_torch as tex + + +@functools.lru_cache(maxsize=None) +def empty_tensor() -> torch.Tensor: + """Get tensor with no entries and no data""" + return torch.Tensor() + + +def canonicalize_fp8_scales( + *, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + fp8_meta: Optional[tex.FP8TensorMeta] = None, + fp8_meta_index: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None] = None, + allow_multiple_offsets: bool = True, +) -> Tuple[Dict[str, torch.Tensor], Dict[str, int]]: + """Canonicalize FP8 scaling factors (scale, amax, scale-inverse) + + If a scaling factor is not provided, try to access it within the + FP8 meta tensors. Returns dict with tensors and dict with tensor + offsets. 
+ + """ + + # Default: use provided scales with no offsets + scale_offset = 0 + amax_offset = 0 + scale_inv_offset = 0 + + # Get scales from FP8 meta tensors if needed + if (fp8_meta is not None) and any(arg is None for arg in (scale, amax, scale_inv)): + if fp8_meta_index is None: + raise ValueError("Provided `fp8_meta` without corresponding `fp8_meta_index`") + fp8_meta_index = int(fp8_meta_index) + if scale is None: + scale = fp8_meta.scale + scale_offset = fp8_meta_index + if amax is None: + amax = fp8_meta.amax_history + amax_offset = fp8_meta_index + if scale_inv is None: + scale_inv = fp8_meta.scale_inv + scale_inv_offset = fp8_meta_index + + # Construct empty tensors if needed + if scale is None: + scale = empty_tensor() + scale_offset = 0 + if amax is None: + amax = empty_tensor() + amax_offset = 0 + if scale_inv is None: + scale_inv = empty_tensor() + scale_inv_offset = 0 + + # Force offsets to be the same if needed + if not allow_multiple_offsets and not scale_offset == amax_offset == scale_inv_offset: + if scale_offset != 0: + scale = scale[scale_offset:] + scale_offset = 0 + if amax_offset != 0: + amax = amax[:, amax_offset:] + amax_offset = 0 + if scale_inv_offset != 0: + scale_inv = scale_inv[scale_inv_offset:] + scale_inv_offset = 0 + + # Pack tensors and offsets into dicts + tensors = dict(scale=scale, amax=amax, scale_inv=scale_inv) + offsets = dict( + scale_offset=scale_offset, + amax_offset=amax_offset, + scale_inv_offset=scale_inv_offset, + ) + return tensors, offsets diff --git a/transformer_engine/pytorch/cpp_extensions/activation.py b/transformer_engine/pytorch/cpp_extensions/activation.py index 767fe25291..f204982aa0 100644 --- a/transformer_engine/pytorch/cpp_extensions/activation.py +++ b/transformer_engine/pytorch/cpp_extensions/activation.py @@ -3,192 +3,235 @@ # See LICENSE for license information. 
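The activation wrappers rewritten below all funnel their scaling factors through the canonicalize_fp8_scales helper introduced above. A small usage sketch, assuming the new transformer_engine.pytorch.cpp_extensions._common module is importable; the explicit-tensor path shown here needs no FP8TensorMeta:

.. code-block:: python

    import torch
    from transformer_engine.pytorch.cpp_extensions._common import canonicalize_fp8_scales

    tensors, offsets = canonicalize_fp8_scales(
        scale=torch.ones(1),
        amax=torch.zeros(1, 1),
        scale_inv=torch.ones(1),
    )
    # Explicit tensors pass through untouched and every offset stays at zero;
    # only when an fp8_meta/fp8_meta_index pair is supplied instead do the
    # offsets pick up the meta-tensor index.
    assert offsets == {"scale_offset": 0, "amax_offset": 0, "scale_inv_offset": 0}
    assert set(tensors) == {"scale", "amax", "scale_inv"}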
"""Python interface for activation extensions""" -from typing import Union +from typing import Optional, Union + import torch -import transformer_engine_torch as tex +import transformer_engine_torch as tex +from ._common import canonicalize_fp8_scales __all__ = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"] def gelu( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """GeLU with FP8 output""" - empty_tensor = torch.Tensor() - if fp8_meta_tensor is not None: - scale = fp8_meta_tensor.scale - amax_history = fp8_meta_tensor.amax_history - scale_inv = fp8_meta_tensor.scale_inv - else: - scale = empty_tensor - amax_history = empty_tensor - scale_inv = empty_tensor + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel return torch.ops.tex_ts.gelu_ts( inp, - scale, - amax_history, - scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, ) def relu( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ReLU with FP8 output""" - empty_tensor = torch.Tensor() - if fp8_meta_tensor is not None: - scale = fp8_meta_tensor.scale - amax_history = fp8_meta_tensor.amax_history - scale_inv = fp8_meta_tensor.scale_inv - else: - scale = empty_tensor - amax_history = empty_tensor - scale_inv = empty_tensor + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel return torch.ops.tex_ts.relu_ts( inp, - scale, - amax_history, - scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, ) def geglu( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """GeGLU with FP8 output""" - empty_tensor = torch.Tensor() - if fp8_meta_tensor is not None: - scale = fp8_meta_tensor.scale - amax_history = fp8_meta_tensor.amax_history - scale_inv = fp8_meta_tensor.scale_inv - else: - scale = empty_tensor - amax_history = empty_tensor - scale_inv = empty_tensor + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + 
allow_multiple_offsets=False, + ) + + # Launch kernel return torch.ops.tex_ts.geglu_ts( inp, - scale, - amax_history, - scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, ) def reglu( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ReGLU with FP8 output""" - empty_tensor = torch.Tensor() - if fp8_meta_tensor is not None: - scale = fp8_meta_tensor.scale - amax_history = fp8_meta_tensor.amax_history - scale_inv = fp8_meta_tensor.scale_inv - else: - scale = empty_tensor - amax_history = empty_tensor - scale_inv = empty_tensor + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel return torch.ops.tex_ts.reglu_ts( inp, - scale, - amax_history, - scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, ) def swiglu( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """SwiGLU with FP8 output""" - empty_tensor = torch.Tensor() - if fp8_meta_tensor is not None: - scale = fp8_meta_tensor.scale - amax_history = fp8_meta_tensor.amax_history - scale_inv = fp8_meta_tensor.scale_inv - else: - scale = empty_tensor - amax_history = empty_tensor - scale_inv = empty_tensor + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel return torch.ops.tex_ts.swiglu_ts( inp, - scale, - amax_history, - scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, ) def qgelu( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """QuickGELU with FP8 output""" - empty_tensor = torch.Tensor() - if fp8_meta_tensor is not None: - scale = fp8_meta_tensor.scale - amax_history = fp8_meta_tensor.amax_history - scale_inv = fp8_meta_tensor.scale_inv - else: - scale = empty_tensor - amax_history = empty_tensor - scale_inv = empty_tensor + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel 
return torch.ops.tex_ts.qgelu_ts( inp, - scale, - amax_history, - scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, ) def srelu( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ReLU with FP8 output""" - empty_tensor = torch.Tensor() - if fp8_meta_tensor is not None: - scale = fp8_meta_tensor.scale - amax_history = fp8_meta_tensor.amax_history - scale_inv = fp8_meta_tensor.scale_inv - else: - scale = empty_tensor - amax_history = empty_tensor - scale_inv = empty_tensor + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel return torch.ops.tex_ts.srelu_ts( inp, - scale, - amax_history, - scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, ) diff --git a/transformer_engine/pytorch/cpp_extensions/cast.py b/transformer_engine/pytorch/cpp_extensions/cast.py index 2856d4727b..cd3c01c785 100644 --- a/transformer_engine/pytorch/cpp_extensions/cast.py +++ b/transformer_engine/pytorch/cpp_extensions/cast.py @@ -4,57 +4,90 @@ """Python interface for cast extensions""" from typing import Optional, Union + import torch -import transformer_engine_torch as tex +import transformer_engine_torch as tex +from ._common import canonicalize_fp8_scales __all__ = ["cast_to_fp8", "cast_from_fp8"] def cast_to_fp8( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, out: Optional[torch.Tensor] = None, -) -> Optional[torch.Tensor]: + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, +) -> torch.Tensor: """Cast input to FP8""" - if out is not None: - if inp.nelement() > 0: - torch.ops.tex_ts.cast_to_fp8_noalloc_ts( - inp, - fp8_meta_tensor.scale, - out, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor, - otype, - ) - return None - - return torch.ops.tex_ts.cast_to_fp8_ts( - inp, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor, - otype, + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, ) + # Launch FP8 cast kernel + if inp.nelement() == 0: + if out is None: + out = torch.empty_like(inp, dtype=torch.uint8) + elif out is None: + out = torch.ops.tex_ts.cast_to_fp8_ts( + inp, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], + otype, + ) + else: + torch.ops.tex_ts.cast_to_fp8_noalloc_ts( + inp, + fp8_scales["scale"], + out, + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], + otype, + ) + return out + def 
cast_from_fp8( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], itype: tex.DType, otype: tex.DType, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Cast input from FP8""" + + # Get scaling factors from FP8 meta tensors if needed + scale_inv_offset = 0 + if (fp8_meta_tensor is not None) and (scale_inv is None): + if fp8_tensor is None: + raise ValueError("Provided `fp8_meta_tensor` without corresponding `fp8_tensor`") + scale_inv = fp8_meta_tensor.scale_inv + scale_inv_offset = int(fp8_tensor) + + # Construct empty tensors if needed + if scale_inv is None: + raise ValueError("Did not provide either `scale_inv` or `fp8_meta_tensor`") + + # Launch FP8 cast kernel return torch.ops.tex_ts.cast_from_fp8_ts( inp, - fp8_meta_tensor.scale_inv, - fp8_tensor, + scale_inv, + scale_inv_offset, itype, otype, ) diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index d0ba644621..cd0ecbaa6c 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -78,6 +78,16 @@ BACKEND_F16m512_FP8_THREADS_PER_CTA = 128 BACKEND_F16arb_ELTS_PER_THREADS = 16 +META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT +META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1 +META_O = tex.FP8FwdTensors.GEMM2_INPUT +META_DO = tex.FP8BwdTensors.GRAD_INPUT2 +META_S = tex.FP8FwdTensors.GEMM3_OUTPUT +META_DP = tex.FP8BwdTensors.GRAD_INPUT3 +# repurpose some unused amax history buffers for partial results of CP fwd and bwd +META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT +META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1 + def fused_attn_fwd_qkvpacked( is_training: bool, @@ -89,11 +99,17 @@ def fused_attn_fwd_qkvpacked( attn_bias: torch.Tensor = None, cu_seqlens_padded: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, + d_scale_qkv_offset: int = META_QKV, d_scale_s: torch.Tensor = None, + d_scale_s_offset: int = META_S, q_scale_s: torch.Tensor = None, + q_scale_s_offset: int = META_S, q_scale_o: torch.Tensor = None, + q_scale_o_offset: int = META_O, amax_s: torch.Tensor = None, + amax_s_offset: int = META_S, amax_o: torch.Tensor = None, + amax_o_offset: int = META_O, attn_scale: float = None, dropout: float = 0.0, fast_zero_fill: bool = True, @@ -128,16 +144,28 @@ def fused_attn_fwd_qkvpacked( cumulative sequence offsets for QKV; shape [batch_size + 1] d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of QKV in FP8 computations + d_scale_qkv_offset: int, default = META_QKV + offset in d_scale_qkv for QKV d_scale_s: torch.Tensor, default = None input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) + d_scale_s_offset: int, default = META_S + offset in d_scale_s for S q_scale_s: torch.Tensor, default = None input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) + q_scale_s_offset: int, default = META_S + offset in q_scale_s for S q_scale_o: torch.Tensor, default = None input tensor for the quantization of O in FP8 computations + q_scale_o_offset: int, default = META_O + offset in q_scale_o for O amax_s: torch.Tensor, default = None output tensor, amax of S, used by the next iteration in FP8 computations + amax_s_offset: int, default = META_S + offset in amax_s for S amax_o: torch.Tensor, default = None output tensor, amax of O, used by the 
next iteration in FP8 computations + amax_o_offset: int, default = META_O + offset in amax_o for O attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; if None, use 1.0/sqrt(head_dim_qk) as the default @@ -248,11 +276,17 @@ def fused_attn_fwd_qkvpacked( qkv_dtype, cu_seqlens_padded, d_scale_qkv, + d_scale_qkv_offset, d_scale_s, + d_scale_s_offset, q_scale_s, + q_scale_s_offset, q_scale_o, + q_scale_o_offset, amax_s, + amax_s_offset, amax_o, + amax_o_offset, attn_bias, rng_gen, rng_elts_per_thread, @@ -448,11 +482,17 @@ def fused_attn_fwd_kvpacked( cu_seqlens_q_padded: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, + d_scale_qkv_offset: int = META_QKV, d_scale_s: torch.Tensor = None, + d_scale_s_offset: int = META_S, q_scale_s: torch.Tensor = None, + q_scale_s_offset: int = META_S, q_scale_o: torch.Tensor = None, + q_scale_o_offset: int = META_O, amax_s: torch.Tensor = None, + amax_s_offset: int = META_S, amax_o: torch.Tensor = None, + amax_o_offset: int = META_O, attn_scale: float = None, dropout: float = 0.0, fast_zero_fill: bool = True, @@ -496,16 +536,28 @@ def fused_attn_fwd_kvpacked( cumulative sequence offsets for KV; shape [batch_size + 1] d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of QKV in FP8 computations + d_scale_qkv_offset: int, default = META_QKV + offset in d_scale_qkv for QKV d_scale_s: torch.Tensor, default = None input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) + d_scale_s_offset: int, default = META_S + offset in d_scale_s for S q_scale_s: torch.Tensor, default = None input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) + q_scale_s_offset: int, default = META_S + offset in q_scale_s for S q_scale_o: torch.Tensor, default = None input tensor for the quantization of O in FP8 computations + q_scale_o_offset: int, default = META_O + offset in q_scale_o for O amax_s: torch.Tensor, default = None output tensor, amax of S, used by the next iteration in FP8 computations + amax_s_offset: int, default = META_S + offset in amax_s for S amax_o: torch.Tensor, default = None output tensor, amax of O, used by the next iteration in FP8 computations + amax_o_offset: int, default = META_O + offset in amax_o for O attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; if None, use 1.0/sqrt(head_dim_qk) as the default @@ -621,11 +673,17 @@ def fused_attn_fwd_kvpacked( cu_seqlens_q_padded, cu_seqlens_kv_padded, d_scale_qkv, + d_scale_qkv_offset, d_scale_s, + d_scale_s_offset, q_scale_s, + q_scale_s_offset, q_scale_o, + q_scale_o_offset, amax_s, + amax_s_offset, amax_o, + amax_o_offset, attn_bias, rng_gen, rng_elts_per_thread, @@ -843,11 +901,17 @@ def fused_attn_fwd( cu_seqlens_q_padded: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, + d_scale_qkv_offset: int = META_QKV, d_scale_s: torch.Tensor = None, + d_scale_s_offset: int = META_S, q_scale_s: torch.Tensor = None, + q_scale_s_offset: int = META_S, q_scale_o: torch.Tensor = None, + q_scale_o_offset: int = META_O, amax_s: torch.Tensor = None, + amax_s_offset: int = META_S, amax_o: torch.Tensor = None, + amax_o_offset: int = META_O, attn_scale: float = None, dropout: float = 0.0, fast_zero_fill: bool = True, @@ -894,17 +958,29 @@ def fused_attn_fwd( cu_seqlens_kv_padded: torch.Tensor, default = None cumulative sequence offsets for KV; shape [batch_size + 
1] d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of Q, K and V in FP8 computations + input tensor for the dequantization of QKV in FP8 computations + d_scale_qkv_offset: int, default = META_QKV + offset in d_scale_qkv for QKV d_scale_s: torch.Tensor, default = None input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) + d_scale_s_offset: int, default = META_S + offset in d_scale_s for S q_scale_s: torch.Tensor, default = None input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) + q_scale_s_offset: int, default = META_S + offset in q_scale_s for S q_scale_o: torch.Tensor, default = None input tensor for the quantization of O in FP8 computations + q_scale_o_offset: int, default = META_O + offset in q_scale_o for O amax_s: torch.Tensor, default = None output tensor, amax of S, used by the next iteration in FP8 computations + amax_s_offset: int, default = META_S + offset in amax_s for S amax_o: torch.Tensor, default = None output tensor, amax of O, used by the next iteration in FP8 computations + amax_o_offset: int, default = META_O + offset in amax_o for O attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; if None, use 1.0/sqrt(head_dim_qk) as the default @@ -1023,11 +1099,17 @@ def fused_attn_fwd( cu_seqlens_q_padded, cu_seqlens_kv_padded, d_scale_qkv, + d_scale_qkv_offset, d_scale_s, + d_scale_s_offset, q_scale_s, + q_scale_s_offset, q_scale_o, + q_scale_o_offset, amax_s, + amax_s_offset, amax_o, + amax_o_offset, attn_bias, rng_gen, rng_elts_per_thread, diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 8502f70491..fd1eb4a810 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -11,7 +11,12 @@ from ..utils import assert_dim_for_fp8_exec -__all__ = ["gemm", "fp8_gemm", "grouped_gemm", "fp8_grouped_gemm"] +__all__ = [ + "gemm", + "fp8_gemm", + "grouped_gemm", + "fp8_grouped_gemm", +] @functools.lru_cache(maxsize=None) @@ -313,7 +318,7 @@ def grouped_gemm( layout: str = "TN", bias: Optional[List[torch.Tensor]] = None, use_bias: bool = False, -) -> Tuple[Union[List[torch.Tensor], None], ...]: +) -> Tuple[List[torch.Tensor], ...]: """Non FP8 Grouped GEMM.""" assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." @@ -380,7 +385,7 @@ def grouped_gemm( def fp8_grouped_gemm( A: List[torch.Tensor], - A_scale_inv: torch.Tensor, + A_scale_inv: List[torch.Tensor], A_fp8_tensor_offset: int, A_dtype: tex.DType, B: List[torch.Tensor], @@ -390,6 +395,7 @@ def fp8_grouped_gemm( out: List[torch.Tensor], out_dtype: torch.dtype, workspaces: List[torch.Tensor], + m_splits: Optional[List[int]] = None, out_offset: Optional[int] = None, fp8_meta_tensor: tex.FP8TensorMeta = None, gelu: bool = False, @@ -398,16 +404,25 @@ def fp8_grouped_gemm( use_bias: bool = False, use_split_accumulator: bool = False, D_dtype: Optional[tex.DType] = None, -) -> Tuple[Union[List[torch.Tensor], None], ...]: +) -> Tuple[List[torch.Tensor], ...]: """ TN layout Grouped GEMM with fp8 inputs. - This method assumes the scale/scale_inv/amax of A/B/out is contiguous in the meta tensor. - scale: [ ...A_scale... | ...B_scale... | ...out_scale...] - scale_inv: [ ...A_scale_inv... | ...B_scale_inv... | ...out_scale_inv...] - amax: [ ...A_amax... | ...B_amax... | ...out_amax...] + Input requirements: + 1. 
If len(A_scale_inv) == num_gemms, len(out) must be 1, and m_splits is not None. + This is used for the calculation of output (fwd) and dgrad (bwd). + 2. if len(A_scale_inv) == 1, len(out) must be num_gemms. This is used for the + calculation of wgrad. """ - num_gemms = len(A) + if num_gemms > 1 and len(A_scale_inv) == num_gemms: + assert len(out) == 1 and m_splits is not None + elif num_gemms > 1 and len(A_scale_inv) == 1: + assert len(out) == num_gemms + elif num_gemms == 1: + assert len(A_scale_inv) == 1 and len(out) == 1 + else: + raise ValueError("Invalid input combinations of A_scale_inv and out.") + empty_tensor = _empty_tensor() empty_tensors = [empty_tensor] * num_gemms if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: @@ -420,41 +435,71 @@ def fp8_grouped_gemm( # Use bfloat16 as default bias_dtype bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype - if gelu: - gelu_input = [ - torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format) - for o in out - ] - else: - gelu_input = empty_tensors bias_dtype = TE_DType[bias_dtype] - + gelu_input = empty_tensors out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype - torch.ops.tex_ts.te_grouped_gemm_ts( - A, - A_scale_inv, - A_fp8_tensor_offset, - A_dtype, - True, # transa - B, - B_scale_inv, - B_fp8_tensor_offset, - B_dtype, - False, # transb - out, - 0 if out_offset is None else out_offset, - empty_tensor if out_offset is None else fp8_meta_tensor.scale, - out_dtype, - empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, - bias if use_bias else empty_tensors, - bias_dtype, - gelu_input, # this is pre_gelu_out - False, # grad - workspaces, - workspaces[0].shape[0], - accumulate, - use_split_accumulator, - ) + if len(A_scale_inv) == 1: + if gelu: + gelu_input = [ + torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format) + for o in out + ] + + torch.ops.tex_ts.te_grouped_gemm_ts( + A, + A_scale_inv[0], + A_fp8_tensor_offset, + A_dtype, + True, # transa + B, + B_scale_inv, + B_fp8_tensor_offset, + B_dtype, + False, # transb + out, + 0 if out_offset is None else out_offset, + empty_tensor if out_offset is None else fp8_meta_tensor.scale, + out_dtype, + empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, + bias if use_bias else empty_tensors, + bias_dtype, + gelu_input, # this is pre_gelu_out + False, # grad + workspaces, + workspaces[0].shape[0], + accumulate, + use_split_accumulator, + ) + else: + if gelu: + gelu_input = [torch.empty((m, A[0].size(0)), dtype=bias_dtype) for m in m_splits] + + torch.ops.tex_ts.te_grouped_gemm_single_output_ts( + A, + A_scale_inv, + A_fp8_tensor_offset, + A_dtype, + True, # transa + B, + B_scale_inv, + B_fp8_tensor_offset, + B_dtype, + False, # transb + m_splits, + out[0], + 0 if out_offset is None else out_offset, + empty_tensor if out_offset is None else fp8_meta_tensor.scale, + out_dtype, + empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, + bias if use_bias else empty_tensors, + bias_dtype, + gelu_input, # this is pre_gelu_out + False, # grad + workspaces, + workspaces[0].shape[0], + accumulate, + use_split_accumulator, + ) return out, gelu_input diff --git a/transformer_engine/pytorch/cpp_extensions/normalization.py b/transformer_engine/pytorch/cpp_extensions/normalization.py index dd90bb0b66..50fd6b7709 100644 --- a/transformer_engine/pytorch/cpp_extensions/normalization.py +++ b/transformer_engine/pytorch/cpp_extensions/normalization.py @@ -4,8 +4,11 
@@ """Python interface for normalization extensions""" from typing import Optional, Tuple, Union + import torch + import transformer_engine_torch as tex +from ._common import canonicalize_fp8_scales __all__ = [ @@ -23,46 +26,55 @@ def layernorm_fwd_fp8( weight: torch.Tensor, bias: torch.Tensor, eps: float, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, sm_margin: int, zero_centered_gamma: bool, ln_out: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """LayerNorm with FP8 output""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + ) + + # Launch kernel if ln_out is not None: return tex.layernorm_fwd_fp8_noalloc( inp, weight, bias, eps, - fp8_meta_tensor.scale, + fp8_scales["scale"], ln_out, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["amax"], + fp8_scales["scale_inv"], otype, sm_margin, zero_centered_gamma, - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) - return tex.layernorm_fwd_fp8( inp, weight, bias, eps, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], otype, sm_margin, zero_centered_gamma, - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) @@ -71,26 +83,41 @@ def layernorm_fwd_fp8_inf( weight: torch.Tensor, bias: torch.Tensor, eps: float, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, sm_margin: int, zero_centered_gamma, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """LayerNorm with FP8 output. This version of layernorm_fwd_fp8 is specialized for inference, and returns only the normalized output. 
""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel ret = torch.ops.tex_ts.layernorm_fwd_fp8_inf_ts( inp, weight, bias, eps, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, sm_margin, zero_centered_gamma, @@ -121,44 +148,53 @@ def rmsnorm_fwd_fp8( inp: torch.Tensor, weight: torch.Tensor, eps: float, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, sm_margin: int, zero_centered_gamma: bool, rmsnorm_out: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """RMSNorm with FP8 output""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + ) + + # Launch kernel if rmsnorm_out is not None: return tex.rmsnorm_fwd_fp8_noalloc( inp, weight, eps, - fp8_meta_tensor.scale, + fp8_scales["scale"], rmsnorm_out, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["amax"], + fp8_scales["scale_inv"], otype, sm_margin, zero_centered_gamma, - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) - return tex.rmsnorm_fwd_fp8( inp, weight, eps, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], otype, sm_margin, zero_centered_gamma, - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) @@ -166,25 +202,40 @@ def rmsnorm_fwd_fp8_inf( inp: torch.Tensor, weight: torch.Tensor, eps: float, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, sm_margin: int, zero_centered_gamma, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """RMSNorm with FP8 output. This version of rmsnorm_fwd_fp8 is specialized for inference, and returns only the normalized output. 
""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel ret = torch.ops.tex_ts.rmsnorm_fwd_fp8_inf_ts( inp, weight, eps, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, sm_margin, zero_centered_gamma, diff --git a/transformer_engine/pytorch/cpp_extensions/padding.py b/transformer_engine/pytorch/cpp_extensions/padding.py new file mode 100644 index 0000000000..41dfbe2466 --- /dev/null +++ b/transformer_engine/pytorch/cpp_extensions/padding.py @@ -0,0 +1,29 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Python interface for transpose extensions""" +from typing import List, Tuple, Union +import torch +import transformer_engine_torch as tex + + +__all__ = [ + "multi_padding_fused", +] + + +def multi_padding_fused( + inp: torch.Tensor, + row_list: List[int], + padded_row_list: List[int], + out: torch.Tensor, +) -> Union[Tuple[List[torch.Tensor], List[torch.Tensor]], None]: + """Padding""" + + tex.fused_multi_row_padding( + inp, + out, + row_list, + padded_row_list, + ) diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index d96b743b9e..ddc3b67e9e 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -4,9 +4,12 @@ """Python interface for transpose extensions""" from typing import List, Optional, Tuple, Union + import torch + import transformer_engine_torch as tex from ..constants import TE_DType +from ._common import canonicalize_fp8_scales, empty_tensor __all__ = [ @@ -20,83 +23,115 @@ def fp8_cast_transpose_fused( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, cast_out: Optional[torch.Tensor] = None, transpose_out: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, noop_flag: Optional[torch.Tensor] = None, -) -> Union[Tuple[torch.Tensor, torch.Tensor], None]: +) -> Tuple[torch.Tensor, torch.Tensor]: """Cast + Transpose with FP8 output""" - return_outputs = False + # Allocate outputs if needed if transpose_out is None: transpose_out = torch.empty(inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8) - return_outputs = True if cast_out is None: cast_out = torch.empty_like(inp, dtype=torch.uint8) - return_outputs = True + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + ) + + # Construct no-op flag if needed if noop_flag is None: - noop_flag = torch.Tensor() + noop_flag = empty_tensor() + # Launch kernel if needed if inp.nelement() > 0: tex.fused_cast_transpose_noop( inp, noop_flag, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], cast_out, 
transpose_out, otype, - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) - if return_outputs: - return cast_out, transpose_out - return None + return cast_out, transpose_out def fp8_cast_transpose_bgrad_fused( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Cast + Transpose + BGRAD with FP8 output""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + ) + + # Launch kernel return tex.fused_cast_transpose_bgrad( inp, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], otype, - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) def fp8_transpose_bgrad_fused( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, grad_bias_type: torch.dtype, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Transpose + BGRAD with FP8 output""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + ) + + # Launch kernel return tex.fused_fp8_transpose_bgrad( inp, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], otype, TE_DType[grad_bias_type], - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) @@ -106,18 +141,30 @@ def fp8_cast_transpose_bgrad_dgelu_fused( fp8_meta_tensor: tex.FP8TensorMeta, fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Cast + Transpose + BGRAD + DGELU with FP8 output""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + ) + + # Launch kernel return tex.fused_cast_transpose_bgrad_dgelu( grad_output, gelu_input, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], otype, - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) @@ -128,6 +175,7 @@ def fp8_multi_cast_transpose_fused( amax_indices: List[int], scale_inv_indices: List[int], otype: tex.DType, + scale_inv: 
Optional[torch.Tensor] = None, ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: """Cast + Transpose with FP8 output""" @@ -135,7 +183,7 @@ def fp8_multi_cast_transpose_fused( input_list, fp8_meta_tensor.scale, fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + scale_inv if scale_inv is not None else fp8_meta_tensor.scale_inv, scale_indices, amax_indices, scale_inv_indices, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index aac693a430..04a1193a71 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -28,6 +28,8 @@ #include #include #include +#include +#include #include #include #include diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 05e4e97112..c797208e06 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -10,6 +10,26 @@ #include "common.h" #include "common/common.h" +/*************************************************************************************************** + * Permutation + **************************************************************************************************/ + +std::tuple> moe_permute_fwd( + at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices, + int64_t num_out_tokens, std::vector workspace, int64_t max_expanded_token_num); + +at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, + int64_t topK); + +at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, + int64_t topK); + +std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd, + const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob); + /*************************************************************************************************** * Attention **************************************************************************************************/ @@ -28,11 +48,13 @@ std::vector fused_attn_fwd_qkvpacked( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread); + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread); std::vector fused_attn_bwd_qkvpacked( size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, @@ -55,11 +77,13 @@ std::vector fused_attn_fwd_kvpacked( const at::Tensor KV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional 
scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread); + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread); std::vector fused_attn_bwd_kvpacked( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, @@ -84,11 +108,13 @@ std::vector fused_attn_fwd( const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread); + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread); std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, @@ -139,6 +165,16 @@ void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count); +void te_grouped_gemm_single_output( + std::vector A, std::vector A_scale_inverse, int A_offset, + transformer_engine::DType A_type, bool transa, std::vector B, + at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb, + std::vector m_splits, at::Tensor D, int D_offset, at::Tensor D_scale, + transformer_engine::DType D_type, at::Tensor D_amax, std::vector bias, + transformer_engine::DType bias_type, std::vector pre_gelu_out, bool grad, + std::vector workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, int math_sm_count); + /*************************************************************************************************** * Transpose **************************************************************************************************/ @@ -315,13 +351,18 @@ at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, fl **************************************************************************************************/ at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax, - at::Tensor scale_inv, transformer_engine::DType otype); + at::Tensor scale_inv, transformer_engine::DType otype, + const int scale_offset = 0, const int amax_offset = 0, + const int scale_inv_offset = 0); void cast_to_fp8_noalloc(const at::Tensor &input, const at::Tensor &scale, at::Tensor output, - at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype); + at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, + const int scale_offset = 0, const int amax_offset = 0, + const 
int scale_inv_offset = 0); at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv, - transformer_engine::DType itype, transformer_engine::DType otype); + transformer_engine::DType itype, transformer_engine::DType otype, + const int scale_inv_offset = 0); /*************************************************************************************************** * Softmax @@ -455,4 +496,12 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, float momentum, float dampening, float lr, bool nesterov, bool first_run, bool wd_after_momentum, float scale); +/*************************************************************************************************** + * padding + **************************************************************************************************/ + +void fused_multi_row_padding(at::Tensor input, at::Tensor output, + std::vector input_row_list, + std::vector padded_input_row_list); + #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 50eb7b830f..b2968a688d 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -83,19 +83,33 @@ std::vector fused_attn_fwd_qkvpacked( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread) { + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; auto qkv_sizes = QKV.sizes().vec(); std::vector qkv_shape{qkv_sizes.begin(), qkv_sizes.end()}; std::vector q_shape; - for (auto i : qkv_shape) { - if (i != 3) { - q_shape.push_back(i); + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + int loc_3 = 0; + switch (layout_group) { + case NVTE_3HD: + loc_3 = qkv_sizes.size() - 3; + break; + case NVTE_H3D: + loc_3 = qkv_sizes.size() - 2; + break; + default: + NVTE_ERROR("Invalid QKV layout group."); + } + for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) { + if (it - qkv_shape.begin() != loc_3) { + q_shape.push_back(*it); } } std::vector o_shape{q_shape.begin(), q_shape.end()}; @@ -122,11 +136,14 @@ std::vector fused_attn_fwd_qkvpacked( NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); } te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(), - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), - scale_O.value().data_ptr(), nullptr); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, + getDataPtr(amax_S.value(), amax_S_offset), + getDataPtr(scale_S.value(), scale_S_offset), + getDataPtr(descale_S.value(), descale_S_offset)); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, + getDataPtr(amax_O.value(), amax_O_offset), + getDataPtr(scale_O.value(), scale_O_offset), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { O.fill_(0); @@ -247,9 +264,21 @@ std::vector fused_attn_bwd_qkvpacked( auto qkv_sizes = QKV.sizes().vec(); std::vector qkv_shape{qkv_sizes.begin(), qkv_sizes.end()}; std::vector q_shape; - for (auto i : qkv_shape) { - if (i != 3) { - q_shape.push_back(i); + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + int loc_3 = 0; + switch (layout_group) { + case NVTE_3HD: + loc_3 = qkv_sizes.size() - 3; + break; + case NVTE_H3D: + loc_3 = qkv_sizes.size() - 2; + break; + default: + NVTE_ERROR("Invalid QKV layout group."); + } + for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) { + if (it - qkv_shape.begin() != loc_3) { + q_shape.push_back(*it); } } auto h = q_shape[q_shape.size() - 2]; @@ -393,11 +422,13 @@ std::vector fused_attn_fwd_kvpacked( const at::Tensor KV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread) { + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; auto q_sizes = Q.sizes().vec(); @@ -429,13 +460,16 @@ std::vector fused_attn_fwd_kvpacked( NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); } te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(), - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), - scale_O.value().data_ptr(), nullptr); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, + getDataPtr(amax_S.value(), amax_S_offset), + getDataPtr(scale_S.value(), scale_S_offset), + getDataPtr(descale_S.value(), descale_S_offset)); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, + getDataPtr(amax_O.value(), amax_O_offset), + getDataPtr(scale_O.value(), scale_O_offset), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { O.fill_(0); @@ -747,11 +781,13 @@ std::vector fused_attn_fwd( const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread) { + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; auto q_sizes = Q.sizes().vec(); @@ -788,15 +824,18 @@ std::vector fused_attn_fwd( NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); } te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(), - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, amax_O.value().data_ptr(), - scale_O.value().data_ptr(), nullptr); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, + getDataPtr(amax_S.value(), amax_S_offset), + getDataPtr(scale_S.value(), scale_S_offset), + getDataPtr(descale_S.value(), descale_S_offset)); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, + getDataPtr(amax_O.value(), amax_O_offset), + getDataPtr(scale_O.value(), scale_O_offset), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { O.fill_(0); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cu b/transformer_engine/pytorch/csrc/extensions/cast.cu index c783c9d988..47f5825866 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cu +++ b/transformer_engine/pytorch/csrc/extensions/cast.cu @@ -6,8 +6,9 @@ #include "extensions.h" -at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax, - at::Tensor scale_inv, transformer_engine::DType otype) { +at::Tensor cast_to_fp8(const at::Tensor& input, const at::Tensor& scale, at::Tensor amax, + at::Tensor scale_inv, transformer_engine::DType otype, + const int scale_offset, const int amax_offset, const int scale_inv_offset) { using namespace transformer_engine; auto input_shape = input.sizes().vec(); std::vector shape{input_shape.begin(), input_shape.end()}; @@ -16,32 +17,45 @@ at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Ten if (input.numel() == 0) return output; + // Get pointers for FP8 scale, amax, scale-inverse + void* scale_dptr = getDataPtr(scale, scale_offset); + void* amax_dptr = getDataPtr(amax, amax_offset); + void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + auto input_cu = makeTransformerEngineTensor(input); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype, amax_dptr, + scale_dptr, scale_inv_dptr); nvte_fp8_quantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return output; } -void cast_to_fp8_noalloc(const at::Tensor &input, const at::Tensor &scale, at::Tensor output, - at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype) { +void cast_to_fp8_noalloc(const at::Tensor& input, const at::Tensor& scale, at::Tensor output, + at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, + const int scale_offset, const int amax_offset, + const int scale_inv_offset) { using namespace transformer_engine; size_t N = static_cast(input.size(0)); size_t H = static_cast(input.size(1)); + // Get pointers for FP8 scale, amax, 
scale-inverse + void* scale_dptr = getDataPtr(scale, scale_offset); + void* amax_dptr = getDataPtr(amax, amax_offset); + void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + auto input_cu = makeTransformerEngineTensor(input); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype, amax_dptr, + scale_dptr, scale_inv_dptr); nvte_fp8_quantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return; } -at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv, - transformer_engine::DType itype, transformer_engine::DType otype) { +at::Tensor cast_from_fp8(const at::Tensor& input, const at::Tensor& scale_inv, + transformer_engine::DType itype, transformer_engine::DType otype, + const int scale_inv_offset) { using namespace transformer_engine; auto input_shape = input.sizes().vec(); std::vector shape{input_shape.begin(), input_shape.end()}; @@ -49,7 +63,7 @@ at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv, auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); auto input_cu = makeTransformerEngineTensor(input.data_ptr(), shape, itype, nullptr, nullptr, - scale_inv.data_ptr()); + getDataPtr(scale_inv, scale_inv_offset)); auto output_cu = makeTransformerEngineTensor(output); nvte_fp8_dequantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu index 01fb94cab4..ba9851e7e8 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cu +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu @@ -115,6 +115,11 @@ void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int if (pre_gelu_out[i].data_ptr() != nullptr) pre_gelu_out[i].zero_(); continue; } + + NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous."); + NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous."); + NVTE_CHECK(D[i].is_contiguous(), "D[", i, "] must be contiguous."); + te_A.emplace_back(make_tensor( A[i].data_ptr(), {static_cast(A[i].size(0)), static_cast(A[i].size(1))}, A_type, nullptr, nullptr, getDataPtr(A_scale_inverse, A_offset + i))); @@ -146,3 +151,64 @@ void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int te_workspace.data(), accumulate, use_split_accumulator, math_sm_count, at::cuda::getCurrentCUDAStream()); } + +void te_grouped_gemm_single_output( + std::vector A, std::vector A_scale_inverse, int A_offset, + transformer_engine::DType A_type, bool transa, std::vector B, + at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb, + std::vector m_splits, at::Tensor D, int D_offset, at::Tensor D_scale, + transformer_engine::DType D_type, at::Tensor D_amax, std::vector bias, + transformer_engine::DType bias_type, std::vector pre_gelu_out, bool grad, + std::vector workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, int math_sm_count) { + using namespace transformer_engine; + std::vector te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace; + std::vector tensor_wrappers; + auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, + transformer_engine::DType dtype, void* amax_dptr, + void* scale_dptr, void* scale_inv_dptr) -> NVTETensor { + tensor_wrappers.emplace_back( + makeTransformerEngineTensor(dptr, shape, 
dtype, amax_dptr, scale_dptr, scale_inv_dptr)); + return tensor_wrappers.back().data(); + }; + NVTE_CHECK(D.is_contiguous(), "D must be contiguous."); + void* d_i_ptr = reinterpret_cast(D.data_ptr()); + for (size_t i = 0; i < A.size(); i++) { + if (m_splits[i] == 0) continue; + NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous."); + NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous."); + te_A.emplace_back(make_tensor( + A[i].data_ptr(), {static_cast(A[i].size(0)), static_cast(A[i].size(1))}, + A_type, nullptr, nullptr, getDataPtr(A_scale_inverse[i], A_offset))); + te_B.emplace_back(make_tensor( + B[i].data_ptr(), {static_cast(B[i].size(0)), static_cast(B[i].size(1))}, + B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i))); + te_D.emplace_back(make_tensor( + d_i_ptr, {static_cast(m_splits[i]), static_cast(A[i].size(0))}, D_type, + getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr)); + te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast(bias[i].size(0))}, + bias_type, nullptr, nullptr, nullptr)); + + const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr + ? std::vector{static_cast(pre_gelu_out[i].size(0))} + : std::vector{static_cast(pre_gelu_out[i].size(0)), + static_cast(pre_gelu_out[i].size(1))}; + te_pre_gelu_out.emplace_back(make_tensor( + pre_gelu_out[i].data_ptr(), gelu_shape, + GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr)); + // Move the D pointer to the next split. + char* char_ptr = reinterpret_cast(d_i_ptr); + char_ptr += m_splits[i] * A[i].size(0) * D.element_size(); + d_i_ptr = reinterpret_cast(char_ptr); + } + for (size_t i = 0; i < workspace.size(); i++) { + te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte, + nullptr, nullptr, nullptr)); + } + + // For now, we only have multi-stream cublas backend. + nvte_multi_stream_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), + te_pre_gelu_out.data(), te_A.size(), transa, transb, grad, + te_workspace.data(), accumulate, use_split_accumulator, + math_sm_count, at::cuda::getCurrentCUDAStream()); +} diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cu b/transformer_engine/pytorch/csrc/extensions/padding.cu new file mode 100644 index 0000000000..d975ebeeef --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/padding.cu @@ -0,0 +1,79 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "extensions.h" + +void fused_multi_row_padding(at::Tensor input, at::Tensor output, + std::vector input_row_list, + std::vector padded_input_row_list) { + using namespace transformer_engine; + + NVTE_CHECK(input_row_list.size() == padded_input_row_list.size(), + "Number of input row list and padded row list must match."); + NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2."); + NVTE_CHECK(output.dim() == 2, "Dimension of output must equal 2."); + + const int num_tensors = input_row_list.size(); + // Extract properties from PyTorch tensors + std::vector input_dptr_list, output_dptr_list; + std::vector> input_shape_list, output_shape_list; + std::vector input_type_list; + void* d_input_ptr = reinterpret_cast(input.data_ptr()); + void* d_output_ptr = reinterpret_cast(output.data_ptr()); + for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { + input_dptr_list.push_back(d_input_ptr); + output_dptr_list.push_back(d_output_ptr); + + // Move the input pointer to the next split. + char* input_char_ptr = reinterpret_cast(d_input_ptr); + const size_t input_dptr_offset = + input_row_list[tensor_id] * input.size(1) * input.element_size(); + input_char_ptr += input_dptr_offset; + d_input_ptr = reinterpret_cast(input_char_ptr); + + input_shape_list.push_back({input_row_list[tensor_id], static_cast(input.size(1))}); + input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); + + // Move the output pointer to the next split. + char* output_char_ptr = reinterpret_cast(d_output_ptr); + const size_t output_dptr_offset = + padded_input_row_list[tensor_id] * output.size(1) * output.element_size(); + output_char_ptr += output_dptr_offset; + d_output_ptr = reinterpret_cast(output_char_ptr); + + output_shape_list.push_back( + {padded_input_row_list[tensor_id], static_cast(output.size(1))}); + } + + // Construct TE tensors + std::vector nvte_input_list, nvte_output_list; + std::vector tensor_wrappers; + auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, + transformer_engine::DType dtype) -> NVTETensor { + tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype)); + return tensor_wrappers.back().data(); + }; + + std::vector padded_num_rows_list; + for (size_t i = 0; i < input_dptr_list.size(); ++i) { + if (input_dptr_list[i] == nullptr || input_row_list[i] == 0) continue; + nvte_input_list.emplace_back( + make_tensor(input_dptr_list[i], input_shape_list[i], input_type_list[i])); + nvte_output_list.emplace_back( + make_tensor(output_dptr_list[i], output_shape_list[i], input_type_list[i])); + padded_num_rows_list.emplace_back(padded_input_row_list[i]); + } + + // Check tensor lists + NVTE_CHECK(nvte_output_list.size() == nvte_input_list.size(), + "Number of input and output tensors must match"); + NVTE_CHECK(padded_num_rows_list.size() == nvte_input_list.size() && + "Number of input and padded row list must match"); + + // Launch TE kernel + nvte_multi_padding(nvte_input_list.size(), nvte_input_list.data(), nvte_output_list.data(), + padded_num_rows_list.data(), at::cuda::getCurrentCUDAStream()); +} diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu new file mode 100644 index 0000000000..0c9bed45e0 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -0,0 +1,170 @@ 
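The fused row-padding kernel above is reached from Python through the multi_padding_fused wrapper added in this change; a rough usage sketch with hypothetical split sizes (each group of input rows is copied into the leading rows of a pre-allocated, row-padded slot in the output):

    import torch
    from transformer_engine.pytorch.cpp_extensions.padding import multi_padding_fused

    hidden = 256                      # hypothetical width
    row_list = [3, 5]                 # rows in each input group
    padded_row_list = [8, 8]          # rows per group after padding

    inp = torch.randn(sum(row_list), hidden, device="cuda")
    out = torch.zeros(sum(padded_row_list), hidden, device="cuda")

    # Each group of row_list[i] rows lands at the start of its
    # padded_row_list[i]-row slot in `out`.
    multi_padding_fused(inp, row_list, padded_row_list, out)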
+/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include "extensions.h" + +std::tuple> moe_permute_fwd( + at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices, + int64_t num_out_tokens, std::vector workspace, int64_t max_expanded_token_num) { + const int num_tokens = input.size(0); + int num_cols = input.size(1); + const int topK = indices.size(1); + + // Initialize the workspace on the first run + if (workspace.empty()) { + auto options = + torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false); + + at::Tensor sorted_indices = torch::empty(max_expanded_token_num, options); + at::Tensor row_id = torch::range(0, max_expanded_token_num - 1, 1, options); + at::Tensor sorted_row_id = + torch::empty(max_expanded_token_num, + torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + + size_t temp_storage_bytes = 0; + int *temp_ptr = nullptr; + cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_ptr, temp_ptr, temp_ptr, + temp_ptr, max_expanded_token_num); + at::Tensor temp_storage = torch::empty( + temp_storage_bytes, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); + + workspace.push_back(sorted_indices); + workspace.push_back(row_id); + workspace.push_back(sorted_row_id); + workspace.push_back(temp_storage); + } + + int *indices_ptr = reinterpret_cast(getDataPtr(indices, 0)); + int *sorted_indices_ptr = reinterpret_cast(getDataPtr(workspace[0], 0)); + int *row_id_ptr = reinterpret_cast(getDataPtr(workspace[1], 0)); + int *sorted_row_id_ptr = reinterpret_cast(getDataPtr(workspace[2], 0)); + + void *d_temp_storage = getDataPtr(workspace[3], 0); + size_t temp_storage_bytes = std::numeric_limits::max(); + + cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, indices_ptr, + sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr, + num_tokens * topK); + + // Activations type + at::ScalarType _st; + if (dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2) + _st = at::ScalarType::Byte; + else + _st = input.scalar_type(); + + // Output buffer alloc + num_out_tokens = (num_out_tokens > 0) ? 
num_out_tokens : num_tokens * topK; + at::Tensor permuted_output = torch::empty( + {num_out_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + at::Tensor row_id_map = torch::empty( + {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto input_cu = makeTransformerEngineTensor( + input.data_ptr(), {static_cast(input.size(0)), static_cast(num_cols)}, dtype); + auto permuted_output_cu = makeTransformerEngineTensor( + permuted_output.data_ptr(), + {static_cast(permuted_output.size(0)), static_cast(num_cols)}, dtype); + auto sorted_row_id_cu = + makeTransformerEngineTensor(sorted_row_id_ptr, {static_cast(num_tokens * topK)}, + transformer_engine::DType::kInt32); + auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); + + nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(), + row_id_map_cu.data(), transformer_engine::TensorWrapper().data(), + transformer_engine::TensorWrapper().data(), + transformer_engine::TensorWrapper().data(), num_tokens, topK, num_cols, + num_out_tokens, stream); + + return std::make_tuple(permuted_output, row_id_map, workspace); +} + +at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, + int64_t topK) { + return moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, topK); +} + +at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, + int64_t topK) { + int num_cols = input.size(1); + + // Activations type + at::ScalarType _st; + if (dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2) + _st = at::ScalarType::Byte; + else + _st = input.scalar_type(); + + // Output buffer alloc + at::Tensor unpermuted_output = torch::empty( + {num_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto input_cu = makeTransformerEngineTensor( + input.data_ptr(), {static_cast(input.size(0)), static_cast(num_cols)}, dtype); + auto unpermuted_output_cu = makeTransformerEngineTensor( + unpermuted_output.data_ptr(), + {static_cast(unpermuted_output.size(0)), static_cast(num_cols)}, dtype); + auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); + auto prob_cu = makeTransformerEngineTensor(prob); + + nvte_unpermute(input_cu.data(), unpermuted_output_cu.data(), row_id_map_cu.data(), prob_cu.data(), + num_tokens, topK, num_cols, stream); + + return unpermuted_output; +} + +std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd, + const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob) { + const int topK = (prob.numel() > 0) ? prob.size(1) : 1; + const int num_tokens = (prob.numel() > 0) ? 
prob.size(0) : row_id_map.size(0); + int num_cols = input_bwd.size(1); + + // Activations type + at::ScalarType _st; + if (dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2) + _st = at::ScalarType::Byte; + else + _st = input_bwd.scalar_type(); + + // Output buffer alloc + at::Tensor act_grad = torch::empty({input_fwd.size(0), num_cols}, + torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + at::Tensor prob_grad = torch::empty( + {num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto input_bwd_cu = makeTransformerEngineTensor( + input_bwd.data_ptr(), {static_cast(input_bwd.size(0)), static_cast(num_cols)}, + dtype); + auto act_grad_cu = makeTransformerEngineTensor( + act_grad.data_ptr(), {static_cast(act_grad.size(0)), static_cast(num_cols)}, + dtype); + auto input_fwd_cu = makeTransformerEngineTensor( + input_fwd.data_ptr(), {static_cast(input_fwd.size(0)), static_cast(num_cols)}, + dtype); + auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); + auto prob_cu = makeTransformerEngineTensor(prob); + auto prob_grad_cu = makeTransformerEngineTensor(prob_grad); + + nvte_permute(input_bwd_cu.data(), act_grad_cu.data(), transformer_engine::TensorWrapper().data(), + row_id_map_cu.data(), prob_cu.data(), prob_grad_cu.data(), input_fwd_cu.data(), + num_tokens, topK, num_cols, 0, stream); + + return std::make_tuple(act_grad, prob_grad); +} diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 11b47ccdec..7bd5a2d8c8 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -10,6 +10,12 @@ #include "../extensions.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // Permutation functions + m.def("moe_permute_fwd", moe_permute_fwd); + m.def("moe_permute_bwd", moe_permute_bwd); + m.def("moe_unpermute_fwd", moe_unpermute_fwd); + m.def("moe_unpermute_bwd", moe_unpermute_bwd); + // Softmax functions m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD", py::call_guard()); @@ -87,10 +93,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc, "Fused Multi-tensor Cast + Transpose with allocating output tensors", py::call_guard()); - m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8", py::call_guard()); + m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8", py::call_guard(), + py::arg("input"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), + py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8", - py::call_guard()); - m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8", py::call_guard()); + py::call_guard(), py::arg("input"), py::arg("scale"), + py::arg("output"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), + py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); + m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8", py::call_guard(), + py::arg("input"), py::arg("scale_inv"), py::arg("itype"), py::arg("otype"), + py::arg("scale_inv_offset") = 0); m.def("te_gemm", &te_gemm, "CublasLt GEMM"); /// TODO Think m.def("te_grouped_gemm", &te_grouped_gemm, "Grouped GEMM"); m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked, 
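The permutation bindings registered above take the router's expert indices and return the expert-sorted activations plus a row_id_map for the reverse pass. A rough sketch of the raw extension calls (the sizes and the kFloat32 dtype are illustrative; the workspace list is filled on the first call and can be reused afterwards):

    import torch
    import transformer_engine_torch as tex

    num_tokens, hidden, topk = 16, 128, 2           # hypothetical sizes
    x = torch.randn(num_tokens, hidden, device="cuda")
    indices = torch.randint(0, 8, (num_tokens, topk), dtype=torch.int32, device="cuda")
    probs = torch.softmax(torch.randn(num_tokens, topk, device="cuda"), dim=-1)

    workspace = []                                  # allocated inside the first call
    permuted, row_id_map, workspace = tex.moe_permute_fwd(
        x, tex.DType.kFloat32, indices,
        num_tokens * topk,                          # num_out_tokens (no token dropping)
        workspace,
        num_tokens * topk,                          # max_expanded_token_num
    )

    # Undo the permutation and combine the top-k copies weighted by probs.
    combined = tex.moe_unpermute_fwd(
        permuted, tex.DType.kFloat32, row_id_map, probs, num_tokens, topk
    )

A higher-level wrapper would normally manage the workspace and autograd; the sketch only illustrates the argument layout of the raw bindings.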
@@ -140,7 +152,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction, "Update amax history and FP8 scale/scale_inv after reduction", py::call_guard()); - + m.def("fused_multi_row_padding", &fused_multi_row_padding, "Fused Multi-tensor padding", + py::call_guard()); // fused apply rope m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD", py::call_guard()); diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp index 8515092ae0..9f31dba669 100644 --- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp +++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp @@ -26,7 +26,7 @@ at::Tensor cast_to_fp8_ts(const at::Tensor &input, const at::Tensor &scale, at:: at::Tensor scale_inv, int64_t fp8_tensor, int64_t otype) { transformer_engine::DType otype_arg = reverse_map_dtype(otype); at::Tensor output = - cast_to_fp8(input, scale[fp8_tensor], amax[0][fp8_tensor], scale_inv[fp8_tensor], otype_arg); + cast_to_fp8(input, scale, amax, scale_inv, otype_arg, fp8_tensor, fp8_tensor, fp8_tensor); return output; } @@ -34,8 +34,8 @@ at::Tensor cast_to_fp8_noalloc_ts(const at::Tensor &input, const at::Tensor &sca at::Tensor output, at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor, int64_t otype) { transformer_engine::DType otype_arg = reverse_map_dtype(otype); - cast_to_fp8_noalloc(input, scale[fp8_tensor], output, amax[0][fp8_tensor], scale_inv[fp8_tensor], - otype_arg); + cast_to_fp8_noalloc(input, scale, output, amax, scale_inv, otype_arg, fp8_tensor, fp8_tensor, + fp8_tensor); return output; } @@ -43,7 +43,7 @@ at::Tensor cast_from_fp8_ts(const at::Tensor &input, const at::Tensor &scale_inv int64_t fp8_tensor, int64_t itype, int64_t otype) { transformer_engine::DType itype_arg = reverse_map_dtype(itype); transformer_engine::DType otype_arg = reverse_map_dtype(otype); - at::Tensor output = cast_from_fp8(input, scale_inv[fp8_tensor], itype_arg, otype_arg); + at::Tensor output = cast_from_fp8(input, scale_inv, itype_arg, otype_arg, fp8_tensor); return output; } @@ -305,6 +305,41 @@ std::vector te_grouped_gemm_ts( return D; } +at::Tensor te_grouped_gemm_single_output_ts( + std::vector A, std::vector A_scale_inverse, int64_t A_offset, + int64_t A_type, int64_t transa, std::vector B, at::Tensor B_scale_inverse, + int64_t B_offset, int64_t B_type, int64_t transb, std::vector m_splits, at::Tensor D, + int64_t D_offset, at::Tensor D_scale, int64_t D_type, at::Tensor D_amax, + std::vector bias, int64_t bias_type, std::vector pre_gelu_out, + int64_t grad, std::vector workspace, int64_t workspaceSize, int64_t accumulate, + int64_t use_split_accumulator) { + // cast inputs to types accepted by te_gemm + transformer_engine::DType A_type_arg = reverse_map_dtype(A_type); + bool transa_arg = static_cast(transa); + transformer_engine::DType B_type_arg = reverse_map_dtype(B_type); + bool transb_arg = static_cast(transb); + transformer_engine::DType D_type_arg = reverse_map_dtype(D_type); + transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type); + bool grad_arg = static_cast(grad); + size_t workspaceSize_arg = static_cast(workspaceSize); + bool accumulate_arg = static_cast(accumulate); + bool use_split_accumulator_arg = static_cast(use_split_accumulator); + + // Set an external SM Margin to all the GEMMs. 
+ // This comes in handy when DP is overlapped with GEMMs + + const int device_id = at::cuda::current_device(); + const int sm_count = transformer_engine::cuda::sm_count(device_id); + int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); + + te_grouped_gemm_single_output(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, + B_scale_inverse, B_offset, B_type_arg, transb_arg, m_splits, D, + D_offset, D_scale, D_type_arg, D_amax, bias, bias_type_arg, + pre_gelu_out, grad_arg, workspace, workspaceSize_arg, + accumulate_arg, use_split_accumulator_arg, num_math_sms); + return D; +} + at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, double eps, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor, @@ -371,6 +406,7 @@ TORCH_LIBRARY(tex_ts, m) { m.def("srelu_ts", &srelu_ts); m.def("te_gemm_ts", &te_gemm_ts); m.def("te_grouped_gemm_ts", &te_grouped_gemm_ts); + m.def("te_grouped_gemm_single_output_ts", &te_grouped_gemm_single_output_ts); m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts); m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts); m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts); diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index fdf65db21e..e9fb11e3b9 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -354,12 +354,8 @@ def backward( # Compute the forward pass. detached_inputs = detach_variable(inputs) - with ( - torch.enable_grad(), - ctx.recompute_ctx, - ctx.torch_gpu_amp_ctx, - ctx.torch_cpu_amp_ctx, - activation_recompute_forward(activation_recompute=True, recompute_phase=True), + with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward( + activation_recompute=True, recompute_phase=True ): outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) @@ -680,13 +676,9 @@ def checkpoint( torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts() def recompute_fn(*args, **kwargs): - with ( - torch.autograd.enable_grad(), - te_recompute_ctx, - user_recompute_ctx, - torch_gpu_amp_forward_ctx, - torch_cpu_amp_forward_ctx, - ): + with torch.autograd.enable_grad(), ( + te_recompute_ctx + ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx: function(*args, **kwargs) # Initialize a new checkpoint frame for each new forward pass. diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index b7f87ad397..c3d8709925 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -3,1024 +3,7 @@ # See LICENSE for license information. """Tensor class with FP8 data""" -from __future__ import annotations -from typing import Any, Dict, Optional, Tuple, Union -import warnings -import torch -from torch.utils._pytree import tree_map -import transformer_engine_torch as tex +from .tensor import Float8Tensor -from .constants import TE_DType -from .cpp_extensions import fp8_cast_transpose_fused -from .fp8 import FP8GlobalStateManager - -aten = torch.ops.aten -c10d = torch.ops.c10d -updated_fp8_params = {} - - -def _make_fp8_attr_property_funcs(name: str) -> Any: - """Make accessors for an FP8 attribute - - We store FP8 attributes in a dictionary so we can share them - between tensors with the same data, e.g. detached tensors. 
For - convenience, we also expose them as property attributes. This - function creates the accessors for property attributes. - - Parameters - ---------- - name: str - Key in dictionary of FP8 attributes - - """ - - def get_func(self) -> Any: - return self._fp8_attrs[name] - - def set_func(self, value: Any) -> None: - self._fp8_attrs[name] = value - - def del_func(self) -> None: - del self._fp8_attrs[name] - - return dict(fget=get_func, fset=set_func, fdel=del_func) - - -class _FromFloat8Func(torch.autograd.Function): - """Cast from FP8 to other dtype""" - - @staticmethod - def forward( - _ctx: torch.autograd.function.FunctionCtx, # unused - tensor: Float8Tensor, - dtype: Optional[torch.dtype] = None, - ) -> torch.Tensor: - if dtype is None: - dtype = tensor.dtype - data = tensor._data.contiguous().view(1, -1).detach() - out = tex.cast_from_fp8( - data, - tensor._scale_inv, - tensor._fp8_dtype, - TE_DType[dtype], - ) - out = out.view(tensor.size()) - return out - - @staticmethod - def backward( - _ctx: torch.autograd.function.FunctionCtx, # unused - grad: torch.Tensor, - ) -> Tuple[Optional[torch.Tensor], ...]: - # Assume that we want gradients in full precision - return grad, None - - -def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None: - """Amax scale and update when there is at least 1 trainable FP8 parameter.""" - param_id = id(param._data) - - if param_id not in FP8GlobalStateManager.fp8_param_to_autocast: - return - - autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id] - - if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params: - return - - if autocast_key in updated_fp8_params: - updated_fp8_params[autocast_key].add(param_id) - else: - updated_fp8_params[autocast_key] = {param_id} - - current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key] - # All FP8 trainable parameters have been updated. 
- if updated_fp8_params[autocast_key] == current_fp8_params_set: - FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True, fp8_weights=True) - del updated_fp8_params[autocast_key] - - -class _ToFloat8Func(torch.autograd.Function): - """Cast to FP8 from other dtype""" - - @staticmethod - def forward( - _ctx: torch.autograd.function.FunctionCtx, # unused - tensor: torch.Tensor, - fp8_meta: Optional[Dict[str, Any]] = None, - fp8_meta_forward: bool = True, - fp8_meta_index: Optional[int] = None, - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, - ) -> Float8Tensor: - - # Manually compute scale-inverse if needed - if scale is not None and scale_inv is None: - if isinstance(scale, torch.Tensor): - scale_inv = scale.reciprocal() - else: - scale_inv = 1 / scale - - # Extract data from FP8 meta tensors if provided - if fp8_meta is not None: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=fp8_meta_forward, - ) - if fp8_meta_index is None: - raise ValueError( - "To initialize Float8Tensor with FP8 meta tensors, " - "the FP8 meta tensor index must also be provided" - ) - if scale is None: - scale = fp8_meta[fp8_meta_key].scale[fp8_meta_index] - if amax is None: - amax = fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] - if scale_inv is None: - scale_inv = fp8_meta[fp8_meta_key].scale_inv[fp8_meta_index] - scale_inv = scale_inv.detach().view(1).clone() - - # Check input tensor - tensor = tensor.contiguous().cuda().detach() - if tensor.dtype not in (torch.float32, torch.bfloat16, torch.float16): - tensor = tensor.float() - - # Check scale - if not isinstance(scale, torch.Tensor): - if scale is None: - scale = 1 - scale = torch.full( - [1], - scale, - dtype=torch.float32, - device=tensor.device, - ) - if scale.numel() != 1: - raise ValueError("Attempted to initialize Float8Tensor with invalid scale tensor") - scale = scale.to(device=tensor.device, dtype=torch.float32) - - # Check scale-inverse - if scale_inv is None: - scale_inv = scale.reciprocal() - scale_inv = scale_inv.to(device=tensor.device, dtype=torch.float32) - - # Check amax - if amax is None: - amax = torch.empty_like(scale) - if not (amax.numel() == 1 and amax.is_cuda and amax.dtype == torch.float32): - raise ValueError("Attempted to initialize Float8Tensor with invalid amax tensor") - - # Cast data to FP8 - data = tex.cast_to_fp8( - tensor.view(1, -1), - scale, - amax, - scale_inv, - fp8_dtype, - ) - data = data.view(tensor.size()) - - # Construct FP8 tensor - return Float8Tensor( - data=data, - fp8_meta=fp8_meta, - fp8_meta_forward=fp8_meta_forward, - fp8_meta_index=fp8_meta_index, - fp8_dtype=fp8_dtype, - fp8_scale_inv=scale_inv, - dtype=tensor.dtype, - ) - - @staticmethod - def backward( - _ctx: torch.autograd.function.FunctionCtx, # unused - grad: torch.Tensor, - ) -> Tuple[Optional[torch.Tensor], ...]: - # Assume that we want gradients in full precision - return grad, None, None, None, None, None, None, None - - -class _IdentityFunc(torch.autograd.Function): - """Identity function - - If constructor keyword-arguments are provided, then construct a - new Float8Tensor using the provided tensor's attributes. 
- - """ - - @staticmethod - def forward( - ctx, - tensor: Float8Tensor, - init_kwargs: Optional[Dict[str, Any]] = None, - ) -> torch.Tensor: - - # Return input tensor if constructor kwargs are not provided - ctx.input_dtype = tensor.dtype - if init_kwargs is None: - return tensor - - # Construct new tensor if constructor kwargs are provided - default_kwargs = dict( - data=tensor._data, - fp8_meta=tensor._fp8_meta, - fp8_meta_forward=tensor._fp8_meta_forward, - fp8_meta_index=tensor._fp8_meta_index, - fp8_dtype=tensor._fp8_dtype, - fp8_scale_inv=tensor._scale_inv, - dtype=tensor.dtype, - ) - for key, val in default_kwargs.items(): - if key not in init_kwargs: - init_kwargs[key] = val - return Float8Tensor(**init_kwargs) - - @staticmethod - def backward(ctx, grad): - return grad.to(ctx.input_dtype), None - - -class _ViewFunc(torch.autograd.Function): - """View function - - View the Float8Tensor using the provided shape. - - """ - - @staticmethod - def forward( - ctx, - tensor: torch.Tensor, - shape: Tuple[int] = None, - ) -> torch.Tensor: - - # Return input tensor if shape is not provided - ctx.shape = tensor.shape - if shape is None: - return tensor - - # Construct new tensor if shape is provided - if isinstance(tensor, Float8Tensor): - return Float8Tensor.make_like( - tensor, - data=tensor._data.view(*shape), - ) - return tensor.view(*shape) - - @staticmethod - def backward( - ctx, - grad: torch.Tensor, - ) -> Tuple[Union[torch.Tensor, None], ...]: - - if isinstance(grad, Float8Tensor): - dgrad = Float8Tensor.make_like( - grad, - data=grad._data.view(ctx.shape), - ) - return dgrad, None - return grad.view(ctx.shape), None - - -class _ReshapeFunc(torch.autograd.Function): - """Reshape function - - Reshape the Float8Tensor using the provided shape. - - """ - - @staticmethod - def forward( - ctx, - tensor: torch.Tensor, - shape: Tuple[int] = None, - ) -> torch.Tensor: - - # Return input tensor if shape is not provided - ctx.shape = tensor.shape - if shape is None: - return tensor - - # Construct new tensor if shape is provided - if isinstance(tensor, Float8Tensor): - return Float8Tensor.make_like( - tensor, - data=tensor._data.reshape(*shape), - ) - return tensor.reshape(*shape) - - @staticmethod - def backward( - ctx, - grad: torch.Tensor, - ) -> Tuple[Union[torch.Tensor, None], ...]: - - if isinstance(grad, Float8Tensor): - dgrad = Float8Tensor.make_like( - grad, - data=grad._data.reshape(ctx.shape), - ) - return dgrad, None - return grad.reshape(ctx.shape), None - - -class Float8Tensor(torch.Tensor): - """Experimental tensor class with FP8 data - - The tensor presents as having a standard, higher-precision dtype, - but the data itself is (scaled) FP8. For most tensor operations, - the data will be cast to the nominal dtype before performing the - operation. - - Parameters - ---------- - data: torch.Tensor - Raw FP8 data in a uint8 tensor - fp8_attrs: dict, optional - FP8 metadata, primarily managed by Float8Tensor. If - provided, all other FP8 configuration is ignored. - fp8_meta: dict, optional - FP8 metadata object, primarily managed by TE modules. - fp8_meta_forward: bool, default = `True` - Whether to access the FP8 metadata for the - forward pass. Ignored if fp8_meta is not - provided. - fp8_meta_index: int, optional - Index to access in FP8 meta tensors. Required if - fp8_meta is provided and otherwise ignored. - fp8_dtype: transformer_engine_torch.DType, tex.DType.kFloat8E4M3 - FP8 format. 
- fp8_scale_inv: torch.Tensor - Reciprocal of the scaling factor applied when - casting to FP8, i.e. the scaling factor that must - be applied when casting from FP8 to higher - precision. Can be inferred from fp8_meta if - provided. - dtype: torch.dtype, default = torch.float32 - Nominal tensor datatype. - - """ - - def __new__( - cls, - *, - data: torch.Tensor, - fp8_attrs: Optional[Dict[str, Any]] = None, - fp8_meta: Optional[Dict[str, Any]] = None, - fp8_meta_forward: bool = True, - fp8_meta_index: Optional[int] = None, - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, - fp8_scale_inv: Optional[torch.Tensor] = None, - dtype: torch.dtype = torch.float32, - ): - - # Check that data buffer is valid - if data.element_size() != 1: - raise ValueError( - f"Float8Tensor requires data buffer with 8-bit dtype (got dtype={data.dtype})" - ) - if data.requires_grad: - raise ValueError("Float8Tensor requires non-differentiable data buffer") - if not data.is_cuda: - data = data.cuda() - - # Initialize tensor object - self = torch.Tensor._make_wrapper_subclass( - cls, - data.size(), - strides=data.stride(), - storage_offset=data.storage_offset(), - dtype=dtype, - layout=data.layout, - requires_grad=data.requires_grad, - device=data.device, - ) - self._data: torch.Tensor = data - - # Initialize dict of class attributes - # Note: We store FP8 attributes in a dictionary so we can - # share them between tensors with the same data, e.g. detached - # tensors. - self._fp8_attrs: dict = {} - if fp8_attrs is not None: - self._fp8_attrs = fp8_attrs - return self - - # FP8 meta tensors - if fp8_meta is not None and fp8_meta_index is None: - raise ValueError( - "To initialize Float8Tensor with FP8 meta tensors, " - "the FP8 meta tensor index must also be provided" - ) - self._fp8_meta: Optional[Dict[str, Any]] = fp8_meta - self._fp8_meta_forward: bool = fp8_meta_forward - self._fp8_meta_index: Optional[int] = fp8_meta_index - - # FP8 dtype - assert fp8_dtype in ( - tex.DType.kFloat8E4M3, - tex.DType.kFloat8E5M2, - ), f"Unsupported fp8_dtype {fp8_dtype}." - self._fp8_dtype: tex.DType = fp8_dtype - - # Transposed version of `_data`. 
- self._transpose: Optional[Float8Tensor] = None - self._transpose_invalid: bool = True - - # FP8 scale-inverse - if fp8_scale_inv is None and self._fp8_meta is not None: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=self._fp8_meta_forward, - ) - fp8_scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] - fp8_scale_inv = fp8_scale_inv.detach().view(1).clone() - if fp8_scale_inv is None: - raise ValueError( - "Attempted to initialize Float8Tensor without specifying scale-inverse" - ) - if not isinstance(fp8_scale_inv, torch.Tensor): - fp8_scale_inv = torch.full( - [1], - fp8_scale_inv, - dtype=torch.float32, - device=self._data.device, - ) - if fp8_scale_inv.numel() != 1: - raise ValueError( - "Attempted to initialize Float8Tensor with invalid scale-inverse tensor" - ) - if fp8_scale_inv.dim() != 1: - fp8_scale_inv = fp8_scale_inv.reshape(1) - if fp8_scale_inv.device != self._data.device or fp8_scale_inv.dtype != torch.float32: - fp8_scale_inv = fp8_scale_inv.to( - device=self._data.device, - dtype=torch.float32, - ) - self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv - - return self - - @classmethod - def make_like( - cls, - tensor: Float8Tensor, - *, - data: torch.Tensor, - fp8_attrs: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> Float8Tensor: - """Use attributes of a Float8Tensor to create another Float8Tensor - - See constructor for list of keyword arguments. - - """ - default_kwargs = dict( - fp8_meta=tensor._fp8_meta, - fp8_meta_forward=tensor._fp8_meta_forward, - fp8_meta_index=tensor._fp8_meta_index, - fp8_dtype=tensor._fp8_dtype, - fp8_scale_inv=tensor._scale_inv, - dtype=tensor.dtype, - ) - for key, val in default_kwargs.items(): - if key not in kwargs: - kwargs[key] = val - return Float8Tensor(data=data, fp8_attrs=fp8_attrs, **kwargs) - - def __repr__(self): - return ( - "Float8Tensor(" - f"fp8_dtype={self._fp8_dtype}, " - f"scale_inv={self._scale_inv.item()}, " - f"data={self.from_float8(dtype=self.dtype)}" - ")" - ) - - def from_float8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: - """ - Construct plain PyTorch tensor from Float8Tensor - - By default the resulting tensor's dtype is the - Float8Tensor's nominal dtype. 
- """ - return _FromFloat8Func.apply(self, dtype) - - @classmethod - def to_float8( - cls, - tensor: torch.Tensor, - *, - fp8_meta: Optional[Dict[str, Any]] = None, - fp8_meta_forward: bool = True, - fp8_meta_index: Optional[int] = None, - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, - ): - """Construct Float8Tensor from plain PyTorch tensor""" - return _ToFloat8Func.apply( - tensor, - fp8_meta, - fp8_meta_forward, - fp8_meta_index, - fp8_dtype, - scale, - amax, - scale_inv, - ) - - def float(self) -> torch.Tensor: - return self.from_float8(dtype=torch.float32) - - def bfloat16(self) -> torch.Tensor: - return self.from_float8(dtype=torch.bfloat16) - - def half(self) -> torch.Tensor: - return self.from_float8(dtype=torch.float16) - - def cpu(self) -> torch.Tensor: - return self.from_float8().cpu() - - def clone(self) -> Float8Tensor: - return _IdentityFunc.apply(self, {"data": self._data.detach().clone()}) - - def view(self, *shape: Tuple[int]) -> Float8Tensor: - return _ViewFunc.apply(self, shape) - - def reshape(self, *shape: Tuple[int]) -> Float8Tensor: - return _ReshapeFunc.apply(self, shape) - - def expand_as(self, other: torch.Tensor): - if other is self: - # Note: expand_as is hackily used to create dummy autograd nodes - # and access the backward graph (see - # https://github.com/pytorch/pytorch/blob/238fb660851268f44ff88127887041fea352fe48/torch/nn/parallel/distributed.py#L1026). - # We equally hackily add a dummy function to handle this - # case. - return _IdentityFunc.apply(self) - return super().expand_as(other) - - def contiguous( - self, - *, - memory_format: torch.memory_format = torch.contiguous_format, - ) -> Float8Tensor: - """Returns tensor with data in provided memory format - - Returns `self` if data is already in correct memory format. - - """ - if self._data.is_contiguous(memory_format=memory_format): - return self - return _IdentityFunc.apply( - self, - {"data": self._data.detach().contiguous(memory_format=memory_format)}, - ) - - def transpose_2d( - self, - *, - force_compute: bool = False, - fill_cache: bool = False, - noop_flag: Optional[torch.Tensor] = None, - cache: Optional[bool] = None, - ) -> torch.Tensor: - """ - 2D transpose with caching support. - - Parameters - ---------- - force_compute: bool, default = `False` - Force computation of transpose. Otherwise use - cached values, if possible. - fill_cache: bool, default = `False` - Cache output tensor for future function calls. - noop_flag: torch.Tensor, optional - float32 flag indicating whether to avoid updating - cached values, if possible. - cache: bool, deprecated - - """ - assert self.dim() == 2, f"{self.dim()}-D transpose not supported." 
- - # Handle deprecated cache kwarg - if cache is not None: - msg = ( - "cache kwarg for Float8Tensor.transpose_2d is deprecated, " - "please use force_compute and fill_cache instead" - ) - warnings.warn(msg, DeprecationWarning) - if cache: - force_compute = False - fill_cache = True - else: - force_compute = True - fill_cache = False - - # Need to compute transpose if cache is invalid - need_compute = force_compute - if self._transpose is None: - need_compute = True - elif self._transpose_invalid: - need_compute = True - - # Need to apply transpose kernel if noop flag is applied - if noop_flag is not None: - need_compute = True - - # Return cached transpose if possible - if not need_compute: - return self._transpose - - # Allocate output if needed - data = self._data.contiguous().reshape(-1, self.size(-1)) - out = self._transpose - if out is None: - out = torch.empty( - (data.size(1), data.size(0)), - dtype=torch.uint8, - device=data.device, - ) - noop_flag = None - else: - self._transpose_invalid = False - - # Apply transpose kernel - fp8_dtype = self._fp8_dtype - if noop_flag is None: - tex.fp8_transpose_noalloc(data, out, fp8_dtype) - else: - noop_flag = noop_flag.to(dtype=torch.float32, device=data.device) - tex.fp8_transpose_noalloc_noop(data, out, noop_flag, fp8_dtype) - - # Fill cache if needed - if fill_cache: - self._transpose = out - self._transpose_invalid = False - - return out - - @torch.no_grad() - def cast_transpose_( - self, - tensor: torch.Tensor, - noop_flag: Optional[torch.Tensor] = None, - ) -> None: - """Cast from tensor and populate transpose cache - - Only supported for 2D tensors. - - Parameters - ---------- - tensor: torch.Tensor - Tensor to copy from. Must have same dimensions as - destination tensor. - noop_flag: torch.Tensor, optional - float32 flag indicating whether to avoid updating - destination tensor. 
- - """ - - # Make sure tensor is in expected format - data = self._data - if ( - tensor.device != data.device - or tensor.dtype not in (torch.float32, torch.float16, torch.bfloat16) - or not tensor.is_contiguous() - ): - dtype = tensor.dtype - if dtype not in (torch.float32, torch.float16, torch.bfloat16): - dtype = torch.float32 - tensor = tensor.to( - device=self.device, - dtype=dtype, - memory_format=torch.contiguous_format, - ) - if tensor.size() != data.size() or data.dim() != 2: - raise ValueError( - "Invalid tensor dimensions for FP8 cast-transpose " - f"(src={tuple(tensor.size())}, dst={tuple(data.size())})" - ) - if not data.is_contiguous(): - raise ValueError( - "FP8 cast-transpose is only supported for `Float8Tensor`s with contiguous data" - ) - if self._fp8_meta is None: - raise ValueError( - "FP8 cast-transpose is only supported for `Float8Tensor`s with FP8 metadata " - ) - - # Construct transpose cache if needed - transpose = self._transpose - if transpose is None or not transpose.is_contiguous(): - transpose = torch.empty( - (data.size(1), data.size(0)), - dtype=torch.uint8, - device=data.device, - ) - self._transpose = transpose - noop_flag = None - - # Launch cast-transpose kernel - fp8_meta_index = int(self._fp8_meta_index) - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=self._fp8_meta_forward, - ) - fp8_meta = self._fp8_meta[fp8_meta_key] - fp8_cast_transpose_fused( - tensor, - fp8_meta, - fp8_meta_index, - self._fp8_dtype, - cast_out=data, - transpose_out=transpose, - noop_flag=noop_flag, - ) - scale = fp8_meta.scale[fp8_meta_index : fp8_meta_index + 1] - scale_inv = self._scale_inv - if noop_flag is None: - torch.reciprocal(scale, out=scale_inv) - else: - torch.where( - noop_flag.bool(), - scale_inv, - scale.reciprocal(), - out=scale_inv, - ) - self._transpose_invalid = False - - @torch.no_grad() - def reset_fp8_meta_scale_inv(self) -> None: - """Replace FP8 meta tensor scale-inverse with cached value - - The FP8 meta tensor scale_inv entry corresponding to this - tensor is replaced with the scale_inv value used to construct - the tensor. - - """ - assert self._fp8_meta is not None, "FP8 meta tensors not found." - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=self._fp8_meta_forward, - ) - self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index].copy_(self._scale_inv[0]) - - def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: - """Create `Float8Tensor` with given nominal dtype - - The new tensor has the same underlying FP8 data. - - """ - return Float8Tensor.make_like( - self, - data=self._data, - fp8_attrs=self._fp8_attrs, - dtype=dtype, - ) - - def _reset_caches(self) -> None: - """ - Set transpose cache as invalid. - Should be called after any in-place operation. 
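A sketch of the transpose-cache workflow implemented by ``transpose_2d`` above (illustrative; assumes a CUDA device):

.. code-block:: python

    import torch
    from transformer_engine.pytorch.float8_tensor import Float8Tensor

    w = Float8Tensor.to_float8(torch.randn(1024, 1024, device="cuda"))

    # First call runs the FP8 transpose kernel and caches the uint8 result.
    w_t = w.transpose_2d(fill_cache=True)

    # While the cache is valid, later calls return the cached buffer.
    assert w.transpose_2d() is w_t

    # With a noop flag, the kernel is always launched and decides on the GPU
    # whether to refresh the cached values (a nonzero flag skips the update).
    noop = torch.zeros(1, dtype=torch.float32, device="cuda")
    w.transpose_2d(noop_flag=noop)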
- """ - self._transpose_invalid = True - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): - - # In-place copy op - if func == aten.copy_.default: - - # Check tensors - dst = args[0] - src = args[1] - if not isinstance(dst, torch.Tensor): - raise RuntimeError("Attempted to copy into something that isn't a PyTorch tensor") - if not isinstance(src, torch.Tensor): - raise RuntimeError("Attempted to copy from something that isn't a PyTorch tensor") - - # Special handling based on which tensors are FP8 - dst_is_fp8 = isinstance(dst, Float8Tensor) - src_is_fp8 = isinstance(src, Float8Tensor) - if dst_is_fp8 and src_is_fp8: - - # Directly copy FP8 data if possible - if dst._fp8_dtype == src._fp8_dtype: - dst._data.copy_(src._data) - dst._scale_inv.copy_(src._scale_inv.detach()) - if dst._fp8_meta is not None: - if src._fp8_meta is None: - src_min, src_max = src.from_float8().aminmax() - src_amax = torch.maximum(-src_min, src_max) - else: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=src._fp8_meta_forward, - ) - fp8_meta_index = src._fp8_meta_index - src_amax = src._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=dst._fp8_meta_forward, - ) - fp8_meta_index = dst._fp8_meta_index - dst_amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] - torch.maximum(src_amax, dst_amax, out=dst_amax) - else: - dst.copy_(src.from_float8()) - - elif not dst_is_fp8 and src_is_fp8: - - # Cast source tensor to higher precision - dst.copy_(src.from_float8()) - - elif dst_is_fp8 and not src_is_fp8: - # Make sure input is in expected format - src = src.expand(dst.size()) - src = src.to( - device=dst.device, - memory_format=torch.contiguous_format, - ) - - # Update scaling factor if FP8 meta tensors are available - if dst._fp8_meta is None: - scale = dst._scale_inv.reciprocal() - amax = torch.empty_like(scale) - else: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=dst._fp8_meta_forward, - ) - fp8_meta_index = dst._fp8_meta_index - scale = dst._fp8_meta[fp8_meta_key].scale[fp8_meta_index] - amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] - dst._scale_inv.copy_(scale.detach().reciprocal()) - - # Cast to FP8 - if not dst._data.is_contiguous(): - raise RuntimeError("Transformer Engine cast kernels require contiguous data") - tex.cast_to_fp8_noalloc( - src.view(1, -1), - scale, - dst._data.view(1, -1), - amax, - dst._scale_inv, - dst._fp8_dtype, - ) - - # This branch is where the FP8 parameters are updated in-place during optimization. - # Handle forward amax reduction. 
- post_optimizer_step_fwd_amax_reduction(dst) - else: - - # Invalid case - raise RuntimeError("Using Float8Tensor copy logic, but no Float8Tensor found") - - # Nothing to return for in-place ops - if dst_is_fp8: - dst._reset_caches() - - return None - - # Slice op - if func == aten.slice.Tensor: - tensor = args[0] - data = tensor._data - data_slice = data.__torch_dispatch__( - func, - types, - [data] + list(args[1:]), - kwargs, - ) - return Float8Tensor.make_like(tensor, data=data_slice) - - # Detach op - if func == aten.detach.default: - # Simply return a new Float8Tensor with the same attrs - return Float8Tensor.make_like( - args[0], - data=args[0]._data, - fp8_attrs=args[0]._fp8_attrs, - ) - - # View op - if func == aten.view.default: - tensor = args[0] - data = tensor._data - data_view = data.__torch_dispatch__( - func, - types, - [data] + list(args[1:]), - kwargs, - ) - return Float8Tensor.make_like( - tensor, - data=data_view, - fp8_attrs=tensor._fp8_attrs, - ) - - def maybe_unwrap(t): - if isinstance(t, Float8Tensor): - return t.from_float8() - return t - - def maybe_update_inplace(arg, new_arg, schema_arg): - """Update values of FP8 tensors - - Keep the same FP8 scaling factors. - - """ - if ( - isinstance(arg, Float8Tensor) - and isinstance(new_arg, torch.Tensor) - and hasattr(schema_arg, "alias_info") - and hasattr(schema_arg.alias_info, "is_write") - and schema_arg.alias_info.is_write - ): - arg.copy_(new_arg) - arg._reset_caches() - - # In-place op - if func._schema.is_mutable: - # Cast to higher precision, perform op, and cast values - # back to original FP8 buffers - new_args = tree_map(maybe_unwrap, args) - new_kwargs = tree_map(maybe_unwrap, kwargs) - schema_args = func._schema.arguments - args_len = len(args) - out = super().__torch_dispatch__(func, types, new_args, new_kwargs) - for arg, new_arg, schema_arg in zip(args, new_args, schema_args): - maybe_update_inplace(arg, new_arg, schema_arg) - for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]): - assert kwarg == new_kwarg == schema_arg.name, "name of the kw argument should match" - maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg) - return None - - # Default op - # Note: cast to higher precision and perform op - args = tree_map(maybe_unwrap, args) - if kwargs is not None: - kwargs = tree_map(maybe_unwrap, kwargs) - out = super().__torch_dispatch__(func, types, args, kwargs) - return out - - @classmethod - def _make_in_reduce_ex( - cls, - data: torch.Tensor, - fp8_dtype: tex.DType, - fp8_scale_inv: torch.Tensor, - dtype: torch.dtype, - ) -> Float8Tensor: - """Build Float8Tensor, for use in __reduce__ - - __reduce_ex__ assumes object constructor has positional - arguments. - - """ - return Float8Tensor( - data=data, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv, - dtype=dtype, - ) - - def __reduce_ex__(self, protocol: int) -> tuple: - """Custom pickling to remove references to FP8 metadata objects""" - return ( - Float8Tensor._make_in_reduce_ex, - (self._data, self._fp8_dtype, self._scale_inv, self.dtype), - ) - - def _get_data(self) -> Float8Tensor: - """Get tensor data property""" - return super().data - - def _set_data(self, tensor: torch.Tensor) -> None: - """Set tensor data property - - Cast tensor to FP8 and store in FP8 buffer. 
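The ``data`` property defined here means that assigning a plain tensor to ``.data`` re-quantizes it into the existing FP8 buffer rather than swapping out the storage, which is how optimizer updates reach FP8 parameters in place. A small sketch (illustrative; assumes a CUDA device):

.. code-block:: python

    import torch
    from transformer_engine.pytorch.float8_tensor import Float8Tensor

    param = Float8Tensor.to_float8(torch.randn(64, 64, device="cuda"))

    # The setter routes through copy_(), casting the new values to FP8
    # using the tensor's existing scaling metadata.
    param.data = torch.randn(64, 64, device="cuda")
    assert isinstance(param, Float8Tensor)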
- - """ - with torch.no_grad(): - self.copy_(tensor) - - # Cast to FP8 when setting Float8Tensor.data - data = property(_get_data, _set_data) - - # Accessors for objects in self._fp8_attrs - # Note: We store FP8 attributes in a dictionary so we can share - # them between tensors with the same data, e.g. detached tensors. - # For convenience, we also expose them as property attributes. - _fp8_meta = property(**_make_fp8_attr_property_funcs("fp8_meta")) - _fp8_meta_forward = property(**_make_fp8_attr_property_funcs("fp8_meta_forward")) - _fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index")) - _fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype")) - _transpose = property(**_make_fp8_attr_property_funcs("transpose")) - _transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid")) - _scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv")) - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - return torch._C._disabled_torch_function_impl(func, types, args, kwargs) +__all__ = ["Float8Tensor"] diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index e15268b998..76679eb064 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -38,6 +38,15 @@ def get_default_fp8_recipe() -> DelayedScaling: return DelayedScaling() +def get_fp8_torch_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> torch.dtype: + """Get fp8 data type according to recipe and tensor""" + if fp8_recipe.fp8_format == Format.E4M3 or ( + fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor + ): + return torch.float8_e4m3fn + return torch.float8_e5m2fn + + def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType: """Get fp8 data type according to recipe and tensor""" if fp8_recipe.fp8_format == Format.E4M3 or ( diff --git a/transformer_engine/pytorch/module/__init__.py b/transformer_engine/pytorch/module/__init__.py index 6994f586b1..ba4755efe3 100644 --- a/transformer_engine/pytorch/module/__init__.py +++ b/transformer_engine/pytorch/module/__init__.py @@ -9,4 +9,6 @@ from .layernorm_mlp import LayerNormMLP from .layernorm import LayerNorm from .rmsnorm import RMSNorm +from .fp8_padding import Fp8Padding +from .fp8_unpadding import Fp8Unpadding from .base import initialize_ub, destroy_ub diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 281e3fe104..23a06e318f 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -52,6 +52,9 @@ def _apply_normalization( fwd_ln_sm_margin: int, zero_centered_gamma: bool, is_grad_enabled: bool, + fp8_scale: Optional[torch.Tensor] = None, + fp8_amax: Optional[torch.Tensor] = None, + fp8_scale_inv: Optional[torch.Tensor] = None, ): normalization_func = _get_normalization_func(normalization, fp8_out, is_grad_enabled, True) @@ -70,6 +73,9 @@ def _apply_normalization( fp8_dtype_forward, fwd_ln_sm_margin, zero_centered_gamma, + scale=fp8_scale, + amax=fp8_amax, + scale_inv=fp8_scale_inv, **output_kwarg, ) else: @@ -82,6 +88,9 @@ def _apply_normalization( fp8_dtype_forward, fwd_ln_sm_margin, zero_centered_gamma, + scale=fp8_scale, + amax=fp8_amax, + scale_inv=fp8_scale_inv, ), None, None, diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 3613e1fa5e..644af2c22c 100644 --- 
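The ``fp8.py`` hunk above adds ``get_fp8_torch_dtype``, which maps a recipe to PyTorch's native float8 dtypes the same way ``get_fp8_te_dtype`` maps to TE dtypes. A minimal sketch of the intended mapping (assumes a PyTorch build that provides the native float8 dtypes):

.. code-block:: python

    import torch
    from transformer_engine.common.recipe import DelayedScaling, Format
    from transformer_engine.pytorch.fp8 import get_fp8_torch_dtype

    recipe = DelayedScaling(fp8_format=Format.HYBRID)

    # Forward (fprop) tensors use E4M3 under both the E4M3 and HYBRID formats.
    assert get_fp8_torch_dtype(recipe, fprop_tensor=True) == torch.float8_e4m3fn

    # Gradient tensors fall back to the E5M2 variant under HYBRID.
    grad_dtype = get_fp8_torch_dtype(recipe, fprop_tensor=False)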
a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -138,15 +138,29 @@ def initialize_ub( ) if ifname is not None: - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - try: - hostname = socket.inet_ntoa( - fcntl.ioctl( - s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) - )[20:24] + # Make sure the ifname found in the environment is a valid network interface + if ifname in [name for _, name in socket.if_nameindex()]: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + hostname = socket.inet_ntoa( + fcntl.ioctl( + s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) + )[20:24] + ) + except OSError as err: + raise OSError(f"Invalid network interface: {ifname}") from err + finally: + s.close() + else: + ifname_warning = ( + f"'{ifname}' is not a valid network interface! `te.initialize_ub()` will" + " attempt to " + + "detect ranks on the same node by matching 'socket.gethostname()', which is " + + "known to fail on virtual clusters like Kubernetes. If Userbuffers " + + "initialization fails, please set the 'NVTE_UB_SOCKET_IFNAME' variable in " + + "your environment to the correct network interface." ) - except OSError as err: - raise OSError(f"Invalid network interface: {ifname}") from err + warnings.warn(ifname_warning, UserWarning) hostnames = [None for _ in range(world_size)] torch.distributed.all_gather_object(hostnames, hostname, world_group) @@ -865,11 +879,17 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: # If primary weights are in fp8, wrap the parameter as Float8Tensor fp8_meta_index = self.param_init_meta[name].fp8_meta_index if self.primary_weights_in_fp8 and fp8_meta_index is not None: + dummy_amax = torch.empty( + (1, 1), + dtype=torch.float32, + device=param.device, + ) # Dummy buffer to avoid overwriting amax history param = Float8Tensor.to_float8( param, fp8_meta=self.fp8_meta, fp8_meta_index=fp8_meta_index, - amax=torch.empty(1, device="cuda"), # Dummy amax to avoid overwriting history. + amax=dummy_amax, + with_transpose_cache=torch.is_grad_enabled(), ) # Redo parameter wrap in case we broke it above @@ -891,7 +911,6 @@ def get_fp8_workspace( cache_name: Optional[str] = None, update_workspace: bool = True, skip_update_flag: Optional[torch.Tensor] = None, - with_transpose: bool = False, fsdp_group: dist_group_type = None, ) -> Float8Tensor: """Get FP8 workspace buffer and maybe update its values @@ -917,27 +936,30 @@ def get_fp8_workspace( skip_update_flag: torch.Tensor, optional GPU flag to skip updating the workspace. Take precedence over `update_workspace` if provided. - with_transpose: bool, default = `False` - Whether to initialize cached transpose in workspace. fsdp_group: bool, default = None FSDP process group that the weights are distributed over. """ - # Construct workspace if needed + # Try getting workspace from cache out = None if cache_name is not None: out = self._fp8_workspaces.get(cache_name, None) - # Gather cached Fp8 workspace if it's distributed - # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work - # for models initialized with Fp8 primary weights. - if ( - not isinstance(out, Float8Tensor) - and fsdp_group is not None - and out._data.shape != tensor.data.shape - ): - _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out) + # Gather cached Fp8 workspace if it's distributed + # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work + # for models initialized with Fp8 primary weights. 
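The ``initialize_ub`` change at the top of this hunk validates ``NVTE_UB_SOCKET_IFNAME`` against the interfaces reported by ``socket.if_nameindex()`` and only warns when it is invalid. A standalone sketch of that check (``pick_ub_ifname`` is a hypothetical helper, not part of this diff):

.. code-block:: python

    import os
    import socket
    from typing import Optional

    def pick_ub_ifname(preferred: Optional[str] = None) -> Optional[str]:
        """Return a network interface name known to the kernel, if any."""
        available = [name for _, name in socket.if_nameindex()]
        if preferred in available:
            return preferred
        # Fall back to the first non-loopback interface.
        return next((name for name in available if name != "lo"), None)

    ifname = pick_ub_ifname(os.getenv("NVTE_UB_SOCKET_IFNAME"))
    if ifname is not None:
        os.environ["NVTE_UB_SOCKET_IFNAME"] = ifname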
+ if ( + out is not None + and not isinstance(out, Float8Tensor) + and fsdp_group is not None + and out._data.shape != tensor.data.shape + ): + _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out) + + # Construct workspace if needed if out is None: + + # FP8 data if tensor is None or fp8_meta_forward is None or fp8_meta_index is None: raise ValueError( "tensor, fp8_meta_forward, and fp8_meta_index kwargs " @@ -947,16 +969,38 @@ def get_fp8_workspace( self.fp8_meta["recipe"], fprop_tensor=fp8_meta_forward, ) + data = torch.empty_like(tensor, dtype=torch.uint8) scale_inv = torch.empty([1], dtype=torch.float32, device=tensor.device) + + # Transpose cache + with_transpose_cache = torch.is_grad_enabled() + if ( + not with_transpose_cache + and is_fp8_activation_recompute_enabled() + and not in_fp8_activation_recompute_phase() + ): + with_transpose_cache = True + data_transpose = None + if with_transpose_cache: + data_transpose = torch.empty( + (tensor.size(-1), tensor.numel() // tensor.size(-1)), + dtype=torch.uint8, + device=tensor.device, + ) + + # Construct FP8 tensor out = Float8Tensor( - data=torch.empty_like(tensor, dtype=torch.uint8), + data=data, fp8_meta=self.fp8_meta, fp8_meta_forward=fp8_meta_forward, fp8_meta_index=fp8_meta_index, fp8_dtype=fp8_dtype, fp8_scale_inv=scale_inv, dtype=tensor.dtype, + data_transpose=data_transpose, ) + + # Update cache if cache_name is not None: self._fp8_workspaces[cache_name] = out update_workspace = True @@ -968,33 +1012,17 @@ def get_fp8_workspace( if update_workspace: if tensor is None: raise ValueError("tensor kwarg must be provided to update FP8 workspace") - if with_transpose: - out.cast_transpose_( - tensor, - noop_flag=skip_update_flag, - ) + if is_in_onnx_export_mode(): + # ONNX export does not support fused cast-transpose + # kernel and requires that FP8 scales can be + # represented with constant ops. + transpose_cache = out._transpose + out._transpose = None + out.quantize_(tensor) + out._scale_inv.fill_(out._scale_inv.item()) + out._transpose = transpose_cache else: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=out._fp8_meta_forward, - ) - fp8_meta = out._fp8_meta[fp8_meta_key] - fp8_meta_index = out._fp8_meta_index - cast_to_fp8( - tensor, - fp8_meta, - fp8_meta_index, - out._fp8_dtype, - out=out._data, - ) - if is_in_onnx_export_mode(): - # ONNX export expects FP8 scales can be - # represented with constant ops. However, copying - # into a buffer involves an expand op for array - # broadcasting. We work around this by filling the - # buffer instead. - out._scale_inv.fill_(fp8_meta.scale_inv[fp8_meta_index].item()) - else: - out._scale_inv.copy_(fp8_meta.scale_inv[fp8_meta_index]) + out.quantize_(tensor, noop_flag=skip_update_flag) return out diff --git a/transformer_engine/pytorch/module/fp8_padding.py b/transformer_engine/pytorch/module/fp8_padding.py new file mode 100644 index 0000000000..60bac91353 --- /dev/null +++ b/transformer_engine/pytorch/module/fp8_padding.py @@ -0,0 +1,123 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""FP8 Padding API""" + +from typing import Union, List + +import torch + +from ..cpp_extensions import ( + multi_padding_fused, +) +from ..jit import no_torch_dynamo + + +__all__ = ["Fp8Padding"] + + +class _Fp8Padding(torch.autograd.Function): + """functional FP8 padding""" + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + m_splits: List[int], + padded_m_splits: List[int], + is_grad_enabled: bool, + ) -> torch.Tensor: + # Make sure input dimensions are compatible + in_features = inp.shape[-1] + + # Allocate cast and transpose output tensor + total_row = sum(padded_m_splits) + out = torch.empty([total_row, in_features], dtype=inp.dtype, device=inp.device) + + multi_padding_fused(inp.view(-1, in_features), m_splits, padded_m_splits, out) + + if is_grad_enabled: + ctx.m_splits = m_splits + ctx.padded_m_splits = padded_m_splits + ctx.requires_dgrad = inp.requires_grad + + return out + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + + grad_input = None + if ctx.requires_dgrad: + grad_output = grad_output.contiguous() + + grad_output_mats = torch.split( + grad_output.view(-1, grad_output.shape[-1]), ctx.padded_m_splits + ) + grad_input = torch.cat( + [ + grad_output_mat[: ctx.m_splits[i]] + for i, grad_output_mat in enumerate(grad_output_mats) + ], + dim=0, + ) + + return (grad_input, None, None, None) + + +class Fp8Padding(torch.nn.Module): + """ + Apply the padding for Grouped GEMM input. + + Parameters + ---------- + num_gemms: int + number of GEMMs to be performed simutaneously. + """ + + def __init__( + self, + num_gemms, + ) -> None: + super().__init__() + + self.num_gemms = num_gemms + + @no_torch_dynamo() + def forward( + self, + inp: torch.Tensor, + m_splits: List[int], + ) -> Union[torch.Tensor, List[int]]: + """ + Apply the padding to the input. + + Parameters + ---------- + inp : torch.Tensor + Input tensor. + m_splits : List[int] + List of integers representing the split of the input tensor. + """ + + assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." + + # FP8 padding calculate + padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits] + + if torch.is_grad_enabled(): + fn = _Fp8Padding.apply + args = [] + else: + fn = _Fp8Padding.forward + args = [None] + + args += ( + inp, + m_splits, + padded_m_splits, + torch.is_grad_enabled(), + ) + out = fn(*args) + + return out, padded_m_splits diff --git a/transformer_engine/pytorch/module/fp8_unpadding.py b/transformer_engine/pytorch/module/fp8_unpadding.py new file mode 100644 index 0000000000..6e08f849ef --- /dev/null +++ b/transformer_engine/pytorch/module/fp8_unpadding.py @@ -0,0 +1,119 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""FP8 Padding API""" + +from typing import List + +import torch + +from ..cpp_extensions import ( + multi_padding_fused, +) +from ..jit import no_torch_dynamo + + +__all__ = ["Fp8Unpadding"] + + +class _Fp8Unpadding(torch.autograd.Function): + """functional FP8 unpadding""" + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + m_splits: List[int], + padded_m_splits: List[int], + is_grad_enabled: bool, + ) -> torch.Tensor: + inputmats = torch.split(inp.view(-1, inp.shape[-1]), padded_m_splits) + out_ret = torch.cat( + [grad_output_mat[: m_splits[i]] for i, grad_output_mat in enumerate(inputmats)], dim=0 + ) + + if is_grad_enabled: + ctx.m_splits = m_splits + ctx.padded_m_splits = padded_m_splits + ctx.requires_dgrad = inp.requires_grad + + return out_ret + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + grad_input = None + if ctx.requires_dgrad: + grad_output = grad_output.contiguous() + + in_features = grad_output.shape[-1] + + # Allocate cast and transpose output tensor + total_row = sum(ctx.padded_m_splits) + grad_input = torch.empty( + [total_row, in_features], dtype=grad_output.dtype, device=grad_output.device + ) + # FP8 pad input for forward, FP8 input transpose for backward wgrad + multi_padding_fused( + grad_output.view(-1, in_features), ctx.m_splits, ctx.padded_m_splits, grad_input + ) + + return (grad_input, None, None, None) + + +class Fp8Unpadding(torch.nn.Module): + """ + Apply the unpadding for Grouped GEMM input. + + Parameters + ---------- + num_gemms: int + number of GEMMs to be performed simutaneously. + """ + + def __init__( + self, + num_gemms, + ) -> None: + super().__init__() + + self.num_gemms = num_gemms + + @no_torch_dynamo() + def forward( + self, + inp: torch.Tensor, + m_splits: List[int], + ) -> torch.Tensor: + """ + Apply the unpadding to the input. + + Parameters + ---------- + inp : torch.Tensor + Input tensor. + m_splits : List[int] + List of integers representing the split of the input tensor. + """ + + assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." + + # FP8 padding calculate + padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits] + + if torch.is_grad_enabled(): + fn = _Fp8Unpadding.apply + args = [] + else: + fn = _Fp8Unpadding.forward + args = [None] + + args += ( + inp, + m_splits, + padded_m_splits, + torch.is_grad_enabled(), + ) + out = fn(*args) + + return out diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index a91ff5c361..14edd64249 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -28,8 +28,6 @@ from ..distributed import ( set_tensor_model_parallel_attributes, get_distributed_world_size, - is_fp8_activation_recompute_enabled, - in_fp8_activation_recompute_phase, ) from ..cpp_extensions import ( cast_to_fp8, @@ -42,21 +40,10 @@ from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ..float8_tensor import Float8Tensor +from ..export import is_in_onnx_export_mode __all__ = ["GroupedLinear"] -""" -The offset for fp8_meta_index. -_GEMM_INPUT = 0 -_GEMM_WEIGHT = num_gemms -_GEMM_OUTPUT = 2 * num_gemms -Must be properly set in GroupedLinear's initialization. 
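Together, the two new modules above let callers pad ragged grouped-GEMM inputs so every split is a multiple of 16 rows and strip the padding afterwards. A usage sketch (token counts are made up; assumes a CUDA device):

.. code-block:: python

    import torch
    from transformer_engine.pytorch.module import Fp8Padding, Fp8Unpadding

    num_gemms, hidden = 3, 256
    m_splits = [17, 32, 5]  # rows per GEMM, not FP8-friendly

    pad = Fp8Padding(num_gemms)
    unpad = Fp8Unpadding(num_gemms)

    inp = torch.randn(sum(m_splits), hidden, device="cuda")

    # Each split is rounded up to a multiple of 16: [32, 32, 16].
    padded, padded_m_splits = pad(inp, m_splits)
    assert padded_m_splits == [(m + 15) // 16 * 16 for m in m_splits]
    assert padded.shape[0] == sum(padded_m_splits)

    # ... a grouped GEMM on `padded` with `padded_m_splits` would go here ...

    # Unpadding restores the original row counts.
    restored = unpad(padded, m_splits)
    assert restored.shape == inp.shape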
-""" -_GEMM_INPUT = 0 -_GEMM_WEIGHT = 0 -_GEMM_OUTPUT = 0 -_GRAD_OUTPUT = 0 - class _GroupedLinear(torch.autograd.Function): """GroupedLinear semi-top level module @@ -75,12 +62,9 @@ def forward( fp8_meta: Dict[str, Any], fuse_wgrad_accumulation: bool, cpu_offloading: bool, - tp_group: Union[dist_group_type, None], - tp_size: int, sequence_parallel: bool, - tensor_parallel: bool, activation_dtype: torch.dtype, - parallel_mode: Union[str, None], + fp8_meta_offsets: Dict[str, int], is_grad_enabled: bool, weights_fp8: List[Union[Float8Tensor, None]], *weights_and_biases: Union[Float8Tensor, torch.Tensor, None], @@ -102,10 +86,11 @@ def forward( inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats] inputmats = [] inputmats_t = [] + inputmat_scale_inv = None - global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT if fp8: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + inputmat_scale_inv = torch.empty([num_gemms], dtype=torch.float32, device=inp.device) if ( not fp8_meta["recipe"].override_linear_precision.wgrad and is_grad_enabled @@ -113,7 +98,9 @@ def forward( and not sequence_parallel ): # FP8 input for forward, FP8 input transpose for backward wgrad - indices = list(range(_GEMM_INPUT, _GEMM_INPUT + num_gemms)) + indices = list( + range(fp8_meta_offsets["input"], fp8_meta_offsets["input"] + num_gemms) + ) inputmats, inputmats_t = fp8_multi_cast_transpose_fused( inputmats_no_fp8, fp8_meta["scaling_fwd"], @@ -121,6 +108,7 @@ def forward( indices, # amax_indices indices, # scale_inv_indices fp8_dtype_forward, + scale_inv=inputmat_scale_inv, ) else: # FP8 input for forward @@ -128,11 +116,24 @@ def forward( cast_to_fp8( inputmats_no_fp8[i], fp8_meta["scaling_fwd"], - _GEMM_INPUT + i, + fp8_meta_offsets["input"] + i, fp8_dtype_forward, + scale_inv=inputmat_scale_inv, ) for i in range(num_gemms) ] + + # Hack for ONNX export + # Note: ONNX models are represented as a graph of tensor + # operations, so the in-place scale-inv update doesn't fit + # very well. We work around this by making it look like + # the scale-inv tensor is initialized with a copy. + # Note: ONNX export expects FP8 scales can be represented + # with constant ops. However, copying into a buffer + # involves an expand op for array broadcasting. We work + # around this by filling the buffer instead. 
+ if is_in_onnx_export_mode(): + inputmat_scale_inv.fill_(inputmat_scale_inv.item()) else: inputmats = inputmats_no_fp8 @@ -153,16 +154,17 @@ def forward( _ = fp8_grouped_gemm( [w._data for w in weights_fp8], - fp8_meta["scaling_fwd"].scale_inv, - _GEMM_WEIGHT, + [w._scale_inv for w in weights_fp8], + 0, # weight offset is 0 for the newly created _scale_inv fp8_dtype_forward, inputmats, - fp8_meta["scaling_fwd"].scale_inv, - _GEMM_INPUT, + inputmat_scale_inv, + 0, fp8_dtype_forward, - torch.split(out, m_splits), + [out], activation_dtype, get_multi_stream_cublas_workspace(), + m_splits=m_splits, bias=biases, use_bias=use_bias, use_split_accumulator=_2X_ACC_FPROP, @@ -178,14 +180,14 @@ def forward( for i in range(num_gemms): # amax of input amin, amax = inputmats[i].aminmax() - fp8_meta["scaling_fwd"].amax_history[0][_GEMM_INPUT + i] = torch.max( - -amin, amax - ).float() + fp8_meta["scaling_fwd"].amax_history[0][fp8_meta_offsets["input"] + i] = ( + torch.max(-amin, amax).float() + ) # amax of weight amin, amax = weights[i].aminmax() - fp8_meta["scaling_fwd"].amax_history[0][_GEMM_WEIGHT + i] = torch.max( - -amin, amax - ).float() + fp8_meta["scaling_fwd"].amax_history[0][fp8_meta_offsets["weight"] + i] = ( + torch.max(-amin, amax).float() + ) out = torch.empty( [sum(m_splits), weights[0].size(0)], @@ -230,7 +232,7 @@ def forward( t.activation_offloading = True ctx.save_for_backward( - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, + inputmat_scale_inv, *saved_inputmats, *saved_inputmats_t, *weights, @@ -250,11 +252,8 @@ def forward( ctx.is_first_microbatch = is_first_microbatch ctx.use_bias = use_bias ctx.sequence_parallel = sequence_parallel - ctx.tensor_parallel = tensor_parallel ctx.inp_shape = inp.shape - ctx.parallel_mode = parallel_mode - ctx.tp_group = tp_group - ctx.tp_size = tp_size + ctx.fp8_meta_offsets = fp8_meta_offsets ctx.requires_dgrad = inp.requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False if ctx.fp8 and requires_grad(inp, weights[0], biases[0]): @@ -270,7 +269,7 @@ def forward( def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: with torch.cuda.nvtx.range("_GroupedLinear_backward"): ( - fwd_scale_inverses, + inputmat_scale_inv, *saved_tensors, ) = ctx.saved_tensors inputmats = saved_tensors[: ctx.num_gemms] @@ -284,7 +283,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], w.main_grad = main_grads[i] weights[i] = w - global _GEMM_INPUT, _GEMM_WEIGHT, _GRAD_OUTPUT # preprocess grad_output grad_output = grad_output.contiguous() grad_output_mats = torch.split( @@ -302,13 +300,18 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], fp8_cast_transpose_bgrad_fused( grad_output_mats[i], ctx.fp8_meta["scaling_bwd"], - _GRAD_OUTPUT + i, + ctx.fp8_meta_offsets["grad_output"] + i, fp8_dtype_backward, ) ) else: if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - indices = list(range(_GRAD_OUTPUT, _GRAD_OUTPUT + ctx.num_gemms)) + indices = list( + range( + ctx.fp8_meta_offsets["grad_output"], + ctx.fp8_meta_offsets["grad_output"] + ctx.num_gemms, + ) + ) grad_output_c, grad_output_t = fp8_multi_cast_transpose_fused( grad_output_mats, ctx.fp8_meta["scaling_bwd"], @@ -322,7 +325,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output_c[i] = cast_to_fp8( grad_output_mats[i], ctx.fp8_meta["scaling_bwd"], - _GRAD_OUTPUT + i, + ctx.fp8_meta_offsets["grad_output"] + i, fp8_dtype_backward, ) @@ -342,18 +345,17 @@ def 
backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) fp8_grouped_gemm( [w.transpose_2d() for w in weights_fp8], - torch.cat( - [w._scale_inv for w in weights_fp8] - ), # avoiding torch.cat requires another interface + [w._scale_inv for w in weights_fp8], 0, # weight offset is 0 for the newly created _scale_inv weights_fp8[0]._fp8_dtype, grad_output_c, ctx.fp8_meta["scaling_bwd"].scale_inv, - _GRAD_OUTPUT, + ctx.fp8_meta_offsets["grad_output"], fp8_dtype_backward, - torch.split(dgrad, ctx.m_splits), + [dgrad], ctx.activation_dtype, get_multi_stream_cublas_workspace(), + m_splits=ctx.m_splits, use_split_accumulator=_2X_ACC_DGRAD, ) else: @@ -396,12 +398,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inp._data if isinstance(inp, Float8Tensor) else inp for inp in inputmats_t ], - fwd_scale_inverses, - _GEMM_INPUT, + [inputmat_scale_inv], + 0, fp8_dtype_forward, grad_output_t, ctx.fp8_meta["scaling_bwd"].scale_inv, - _GRAD_OUTPUT, + ctx.fp8_meta_offsets["grad_output"], fp8_dtype_backward, wgrad_list, ctx.activation_dtype, @@ -482,12 +484,9 @@ def handle_custom_ddp_from_mcore(w, wgrad): None, # fp8_meta None, # fuse_wgrad_accumulation None, # cpu_offloading - None, # tp_group - None, # tp_size None, # sequence_parallel - None, # tensor_parallel None, # activation_dtype - None, # parallel_mode + None, # fp8_meta_offsets None, # is_grad_enabled None, # weights_fp8 *wgrad_list, @@ -513,31 +512,14 @@ class GroupedLinear(TransformerEngineBaseModule): used for initializing weights in the following way: `init_method(weight)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. get_rng_state_tracker : Callable, default = `None` - used to get the random number generator state tracker for initilizeing weights. + used to get the random number generator state tracker for initializing weights. rng_tracker_name : str, default = `None` the param passed to get_rng_state_tracker to get the specific rng tracker. device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. - Parallelism parameters - ---------------------- - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - tp_group : ProcessGroup, default = `None` - tensor parallel process group. - tp_size : int, default = 1 - used as TP (tensor parallel) world size when TP groups are not formed during - initialization. In this case, users must call the - `set_tensor_parallel_group(tp_group)` method on the initialized module before the - forward pass to supply the tensor parallel group needed for tensor and sequence - parallel collectives. - parallel_mode : {None, 'Column', 'Row'}, default = `None` - used to decide whether this GroupedLinear layer is Column Parallel Linear or Row - Parallel Linear as described `here `_. - When set to `None`, no communication is performed. 
- Optimization parameters ----------------------- fuse_wgrad_accumulation : bool, default = 'False' @@ -598,8 +580,7 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name - global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT - _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT = 0, num_gemms, 2 * num_gemms + self._offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0} if tp_group is None: self.tp_size = tp_size @@ -636,7 +617,7 @@ def __init__( ), init_fn=init_method, get_rng_state_tracker=get_rng_state_tracker, - fp8_meta_index=_GEMM_WEIGHT + i, + fp8_meta_index=self._offsets["weight"] + i, ) # Construct bias parameters if needed @@ -743,22 +724,12 @@ def forward( weight_tensors_fp8 = [None] * self.num_gemms if self.fp8: - with_transpose = torch.is_grad_enabled() - if ( - not with_transpose - and is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() - ): - with_transpose = True for i in range(self.num_gemms): if isinstance(weight_tensors[i], Float8Tensor): - # Fill transpose cache in FP8 tensor if needed - update_transpose_cache = with_transpose - if update_transpose_cache: - update_transpose_cache = ( - is_first_microbatch or skip_fp8_weight_update is not None - ) - if update_transpose_cache: + # Make sure transpose cache is valid, if present + # Note: Transpose cache may have been invalidated + # externally, e.g. by optimizer. + if weight_tensors[i]._transpose is not None: weight_tensors[i].transpose_2d( fill_cache=True, noop_flag=skip_fp8_weight_update, @@ -769,11 +740,10 @@ def forward( weight_tensors_fp8[i] = self.get_fp8_workspace( tensor=weight_tensors[i], fp8_meta_forward=True, - fp8_meta_index=_GEMM_WEIGHT + i, + fp8_meta_index=self._offsets["weight"] + i, cache_name=(None if is_first_microbatch is None else f"weight{i}"), update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, - with_transpose=with_transpose, ) from ..cpu_offload import CPUOffloadEnabled @@ -794,12 +764,9 @@ def forward( self.fp8_meta, self.fuse_wgrad_accumulation, CPUOffloadEnabled, - self.tp_group, - self.tp_size, self.sequence_parallel, - self.tp_size > 1, self.activation_dtype, - self.parallel_mode, + self._offsets, torch.is_grad_enabled(), weight_tensors_fp8, *weight_tensors, diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index ec33ad2033..292fcd06de 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -110,7 +110,7 @@ class LayerNorm(torch.nn.Module): y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. 
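The per-instance ``_offsets`` dict introduced in the ``GroupedLinear`` hunk above replaces the old module-level ``_GEMM_*`` globals. A sketch of the resulting FP8 meta-tensor index layout (illustrative values):

.. code-block:: python

    num_gemms = 4
    offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0}

    # Forward ("scaling_fwd") meta tensors: inputs at indices 0..3,
    # weights at 4..7, outputs at 8..11 for this instance.
    weight_meta_indices = [offsets["weight"] + i for i in range(num_gemms)]
    assert weight_meta_indices == [4, 5, 6, 7]

    # Backward ("scaling_bwd") meta tensors: grad outputs start at index 0.
    grad_output_meta_indices = [offsets["grad_output"] + i for i in range(num_gemms)]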
""" diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 10560cdad6..92030a7f7a 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -36,8 +36,6 @@ allreduce, reduce_scatter_along_first_dim, gather_along_first_dim, - is_fp8_activation_recompute_enabled, - in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, _fsdp_gather_tensors, ) @@ -46,6 +44,8 @@ from ..graph import is_graph_capturing from ._common import _apply_normalization, _noop_cat from ..float8_tensor import Float8Tensor +from ..export import is_in_onnx_export_mode +from ..tensor import QuantizedTensor __all__ = ["LayerNormLinear"] @@ -90,6 +90,7 @@ def forward( ub_overlap_rs_dgrad: bool, ub_overlap_ag: bool, ub_name: str, + fp8_output: bool, fsdp_group: Union[dist_group_type, None], ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible @@ -126,8 +127,13 @@ def forward( inputmat, dtype=ln_out_dtype, memory_format=torch.contiguous_format ) + # Objects for FP8 cast fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + ln_out_scale_inv = None + if fp8: + ln_out_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device) + # Launch normalization kernel ln_out, mu, rsigma = _apply_normalization( inputmat, ln_out, @@ -140,6 +146,7 @@ def forward( fwd_ln_sm_margin, zero_centered_gamma, is_grad_enabled, + fp8_scale_inv=ln_out_scale_inv, ) # Column Parallel Linear @@ -172,6 +179,7 @@ def forward( tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, out=ln_out_fp8, + scale_inv=ln_out_scale_inv, ) ln_out = torch.empty_like(ln_out_fp8) else: @@ -180,6 +188,7 @@ def forward( fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, + scale_inv=ln_out_scale_inv, ) if ln_out_gathered: rank = torch.distributed.get_rank(tp_group) @@ -199,7 +208,19 @@ def forward( assert isinstance(weight_fp8, Float8Tensor) - if fp8_meta["recipe"].fp8_mha: + # Hack for ONNX export + # Note: ONNX models are represented as a graph of tensor + # operations, so the in-place scale-inv update doesn't fit + # very well. We work around this by making it look like + # the scale-inv tensor is initialized with a copy. + # Note: ONNX export expects FP8 scales can be represented + # with constant ops. However, copying into a buffer + # involves an expand op for array broadcasting. We work + # around this by filling the buffer instead. 
+ if is_in_onnx_export_mode(): + ln_out_scale_inv.fill_(ln_out_scale_inv.item()) + + if fp8_output: out_index, meta_tensor, output_te_dtype, output_dtype = ( tex.FP8FwdTensors.GEMM1_OUTPUT, fp8_meta["scaling_fwd"], @@ -219,8 +240,8 @@ def forward( 0, weight_fp8._fp8_dtype, ln_out_total, - fp8_meta["scaling_fwd"].scale_inv, - tex.FP8FwdTensors.GEMM1_INPUT, + ln_out_scale_inv, + 0, fp8_dtype_forward, output_dtype, get_workspace(), @@ -306,7 +327,7 @@ def forward( weight_fp8, weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, ln_out if weight.requires_grad else None, - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, + ln_out_scale_inv, ) ctx.activation_dtype = activation_dtype @@ -377,7 +398,7 @@ def backward( weight_fp8, main_grad, ln_out, - fwd_scale_inverses, + ln_out_scale_inv, ) = ctx.saved_tensors # Gather intermediate/activation tensors if needed @@ -570,8 +591,8 @@ def backward( ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) wgrad, _ = tex.fp8_gemm( ln_out_total_t, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM1_INPUT, + ln_out_scale_inv, + 0, fp8_dtype_forward, ( grad_output_t._data @@ -596,8 +617,8 @@ def backward( else: ln_out_total_c = torch.ops.tex_ts.cast_from_fp8_ts( ln_out_total, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM1_INPUT, + ln_out_scale_inv, + 0, fp8_dtype_forward, TE_DType[ctx.activation_dtype], ) @@ -744,6 +765,7 @@ def backward( None, # ub_overlap_rs_dgrad None, # ub_overlap_ag None, # ub_name + None, # fp8_output None, # fsdp_group ) @@ -794,7 +816,7 @@ class LayerNormLinear(TransformerEngineBaseModule): y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. @@ -810,7 +832,7 @@ class LayerNormLinear(TransformerEngineBaseModule): `set_tensor_parallel_group(tp_group)` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. - parallel_mode : {None, 'Column', 'Row'}, default = `None` + parallel_mode : {None, 'column', 'row'}, default = `None` used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described `here `_. When set to `None`, no communication is performed. @@ -1096,6 +1118,7 @@ def forward( self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None, + fp8_output: Optional[bool] = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply layer normalization to the input followed by a linear transformation. 
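The new ``fp8_output`` flag lets the forward GEMM emit FP8 directly, e.g. for consumption by a following FP8-aware op, instead of keying this behavior off the recipe's ``fp8_mha`` field. A usage sketch (assumes an FP8-capable GPU):

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling, Format

    layer = te.LayerNormLinear(1024, 1024)
    recipe = DelayedScaling(fp8_format=Format.HYBRID)

    x = torch.randn(32, 1024, device="cuda")
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        y = layer(x)                       # output in the activation dtype
        y_fp8 = layer(x, fp8_output=True)  # output left in FP8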
@@ -1127,14 +1150,14 @@ def forward( # Get concatenated weight and bias tensors unfused_weights = [getattr(self, name) for name in self.weight_names] - if any(isinstance(w, Float8Tensor) for w in unfused_weights): + if any(isinstance(w, QuantizedTensor) for w in unfused_weights): if self.fp8: if len(unfused_weights) != 1: raise RuntimeError( - "Splitting Float8Tensor into multiple params is not supported" + "Splitting QuantizedTensor into multiple params is not supported" ) else: - unfused_weights = [w.from_float8() for w in unfused_weights] + unfused_weights = [w.dequantize() for w in unfused_weights] weight_tensor = _noop_cat(unfused_weights) if self.use_bias: bias_tensor = _noop_cat( @@ -1146,32 +1169,18 @@ def forward( # Initialize FP8 weights if needed weight_fp8 = None if self.fp8: - with_transpose = torch.is_grad_enabled() - if ( - not with_transpose - and is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() - ): - with_transpose = True if isinstance(weight_tensor, Float8Tensor): - # Fill transpose cache in FP8 tensor if needed - update_transpose_cache = with_transpose - if update_transpose_cache: - update_transpose_cache = ( - is_first_microbatch or skip_fp8_weight_update is not None - ) - if update_transpose_cache: + # Make sure transpose cache is valid, if present + # Note: Transpose cache may have been invalidated + # externally, e.g. by optimizer. + if weight_tensor._transpose is not None: weight_tensor.transpose_2d( fill_cache=True, noop_flag=skip_fp8_weight_update, ) else: # FP8 cast to workspace buffer - update_workspace = ( - is_first_microbatch is None - or is_first_microbatch - or skip_fp8_weight_update is not None - ) + update_workspace = is_first_microbatch is None or is_first_microbatch weight_fp8 = self.get_fp8_workspace( tensor=weight_tensor, fp8_meta_forward=True, @@ -1179,7 +1188,6 @@ def forward( cache_name=(None if is_first_microbatch is None else "weight"), update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, - with_transpose=with_transpose, ) from ..cpu_offload import CPUOffloadEnabled @@ -1223,6 +1231,7 @@ def forward( self.ub_overlap_rs_dgrad, self.ub_overlap_ag, self.ub_name, + fp8_output, self.fsdp_group, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index dc9bef645f..6d5609ccd2 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -42,8 +42,6 @@ allreduce, reduce_scatter_along_first_dim, gather_along_first_dim, - is_fp8_activation_recompute_enabled, - in_fp8_activation_recompute_phase, use_reentrant_activation_recompute, _fsdp_scatter_tensors, _fsdp_gather_tensors, @@ -1195,7 +1193,7 @@ class LayerNormMLP(TransformerEngineBaseModule): y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. 
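With the explicit transpose-cache flags removed, the workspace updates above are driven by ``is_first_microbatch`` (with ``skip_fp8_weight_update`` still taking precedence on the GPU): the FP8 weight workspaces are recast when ``is_first_microbatch`` is ``True`` or ``None`` and reused otherwise. A sketch of a gradient-accumulation loop driving this (assumes an FP8-capable GPU):

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    model = te.LayerNormMLP(1024, 4096)
    num_microbatches = 4

    for i in range(num_microbatches):
        x = torch.randn(32, 1024, device="cuda")
        with te.fp8_autocast(enabled=True):
            # Weight workspaces are re-quantized only on the first microbatch.
            y = model(x, is_first_microbatch=(i == 0))
        y.sum().backward()  # gradients accumulate across microbatches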
@@ -1485,19 +1483,8 @@ def forward( fc2_weight_fp8 = None if self.fp8: update_workspace = is_first_microbatch is None or is_first_microbatch - with_transpose = torch.is_grad_enabled() - if ( - is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() - ): - with_transpose = True - update_transpose_cache = with_transpose - if update_transpose_cache: - update_transpose_cache = ( - is_first_microbatch or skip_fp8_weight_update is not None - ) if isinstance(fc1_weight, Float8Tensor): - if update_transpose_cache: + if fc1_weight._transpose is not None: fc1_weight.transpose_2d( fill_cache=True, noop_flag=skip_fp8_weight_update, @@ -1513,10 +1500,9 @@ def forward( cache_name=cache_name, update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, - with_transpose=with_transpose, ) if isinstance(fc2_weight, Float8Tensor): - if update_transpose_cache: + if fc2_weight._transpose is not None: fc2_weight.transpose_2d( fill_cache=True, noop_flag=skip_fp8_weight_update, @@ -1532,7 +1518,6 @@ def forward( cache_name=cache_name, update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, - with_transpose=with_transpose, ) # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 68d333262d..8e19a65a28 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -33,8 +33,6 @@ allreduce, reduce_scatter_along_first_dim, gather_along_first_dim, - is_fp8_activation_recompute_enabled, - in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, _fsdp_gather_tensors, ) @@ -48,6 +46,8 @@ from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ..float8_tensor import Float8Tensor +from ..export import is_in_onnx_export_mode +from ..tensor import QuantizedTensor __all__ = ["Linear"] @@ -81,12 +81,10 @@ def forward( ub_overlap_rs: bool, ub_overlap_ag: bool, ub_name: str, - is_first_module_in_mha: bool, + fp8_output: bool, fsdp_group: Union[dist_group_type, None], ) -> torch.Tensor: is_input_fp8 = isinstance(inp, Float8Tensor) - if is_input_fp8: - fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT] = inp._scale_inv[0] # Make sure input dimensions are compatible in_features = weight.shape[-1] @@ -103,19 +101,14 @@ def forward( inputmat = cast_if_needed(inputmat, activation_dtype) inputmat_t = None inputmat_no_fp8 = inputmat + inputmat_scale_inv = None if fp8: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) if isinstance(inputmat, Float8Tensor): - if ( - not fp8_meta["recipe"].override_linear_precision.wgrad - and is_grad_enabled - and weight.requires_grad - and not sequence_parallel - ): - # FP8 input for forward, FP8 input transpose for backward wgrad - inputmat_t = inputmat.transpose_2d() + inputmat_scale_inv = inputmat._scale_inv else: + inputmat_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device) if ( not fp8_meta["recipe"].override_linear_precision.wgrad and is_grad_enabled @@ -128,6 +121,7 @@ def forward( fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, + scale_inv=inputmat_scale_inv, ) else: # FP8 input for forward @@ -136,8 +130,21 @@ def forward( fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, + scale_inv=inputmat_scale_inv, ) + # Hack for ONNX export + # Note: ONNX models are represented as a graph of tensor + # operations, so the in-place 
scale-inv update doesn't fit + # very well. We work around this by making it look like + # the scale-inv tensor is initialized with a copy. + # Note: ONNX export expects FP8 scales can be represented + # with constant ops. However, copying into a buffer + # involves an expand op for array broadcasting. We work + # around this by filling the buffer instead. + if is_in_onnx_export_mode(): + inputmat_scale_inv.fill_(inputmat_scale_inv.item()) + # Column Parallel Linear if parallel_mode == "column" and sequence_parallel: inputmat_total, _ = gather_along_first_dim(inputmat, tp_group) @@ -153,7 +160,7 @@ def forward( assert isinstance(weight_fp8, Float8Tensor) - if is_first_module_in_mha: + if fp8_output: proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( tex.FP8FwdTensors.GEMM1_OUTPUT, fp8_meta["scaling_fwd"], @@ -206,8 +213,8 @@ def forward( if isinstance(inputmat_total, Float8Tensor) else inputmat_total ), - fp8_meta["scaling_fwd"].scale_inv, - tex.FP8FwdTensors.GEMM1_INPUT, + inputmat_scale_inv, + 0, fp8_dtype_forward, proj_out_pttype, get_workspace(), @@ -222,7 +229,7 @@ def forward( fp8_meta_tensor=meta_tensor, D_dtype=proj_out_tetype, ) - if is_first_module_in_mha: + if fp8_output: out = Float8Tensor( data=out, fp8_meta=fp8_meta, @@ -312,10 +319,10 @@ def forward( ctx.save_for_backward( saved_inputmat, saved_inputmat_t, + inputmat_scale_inv, weight, weight_fp8, weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, ) ctx.activation_dtype = activation_dtype @@ -364,10 +371,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ( inputmat, inputmat_t, + inputmat_scale_inv, weight, weight_fp8, main_grad, - fwd_scale_inverses, ) = ctx.saved_tensors # Gather intermediate/activation tensors if needed @@ -520,8 +527,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if isinstance(inputmat_t_total, Float8Tensor) else inputmat_t_total ), - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM1_INPUT, + inputmat_scale_inv, + 0, fp8_dtype_forward, grad_output_t, ctx.fp8_meta["scaling_bwd"].scale_inv, @@ -621,7 +628,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # ub_overlap_rs None, # ub_overlap_ag None, # ub_name - None, # is_first_module_in_mha + None, # fp8_output None, # fsdp_group ) @@ -643,7 +650,7 @@ class Linear(TransformerEngineBaseModule): used for initializing weights in the following way: `init_method(weight)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. get_rng_state_tracker : Callable, default = `None` - used to get the random number generator state tracker for initilizeing weights. + used to get the random number generator state tracker for initializing weights. rng_tracker_name : str, default = `None` the param passed to get_rng_state_tracker to get the specific rng tracker. parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None @@ -655,7 +662,7 @@ class Linear(TransformerEngineBaseModule): names that end in `_weight` or `_bias`, so trailing underscores are stripped from any provided names. device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. 
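The ONNX-export workaround introduced above replaces the in-place copy of the scale-inverse with a `fill_` from a Python scalar, so the tracer records a constant initializer rather than an Expand plus copy. A tiny self-contained sketch of the difference, independent of Transformer Engine; the value 0.125 is only an example:

    import torch

    scale_inv = torch.empty(1, dtype=torch.float32)

    # What the cast kernel effectively does: an in-place write into the buffer.
    scale_inv.copy_(torch.tensor([0.125]))

    # Export-time workaround: re-materialize the same value via fill_, which the
    # ONNX tracer can fold into a constant instead of a broadcasting copy.
    scale_inv.fill_(scale_inv.item())
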
@@ -671,7 +678,7 @@ class Linear(TransformerEngineBaseModule): `set_tensor_parallel_group(tp_group)` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. - parallel_mode : {None, 'Column', 'Row'}, default = `None` + parallel_mode : {None, 'column', 'row'}, default = `None` used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described `here `_. When set to `None`, no communication is performed. @@ -899,7 +906,7 @@ def forward( self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None, - is_first_module_in_mha: Optional[bool] = False, + fp8_output: Optional[bool] = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply the linear transformation to the input. @@ -930,21 +937,19 @@ def forward( with self.prepare_forward( inp, is_first_microbatch, - allow_non_contiguous=isinstance(inp, Float8Tensor), + allow_non_contiguous=isinstance(inp, QuantizedTensor), ) as inp: - is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha - # Get concatenated weight and bias tensors unfused_weights = [getattr(self, name) for name in self.weight_names] - if any(isinstance(w, Float8Tensor) for w in unfused_weights): + if any(isinstance(w, QuantizedTensor) for w in unfused_weights): if self.fp8: if len(unfused_weights) != 1: raise RuntimeError( - "Splitting Float8Tensor into multiple params is not supported" + "Splitting QuantizedTensor into multiple params is not supported" ) else: - unfused_weights = [w.from_float8() for w in unfused_weights] + unfused_weights = [w.dequantize() for w in unfused_weights] weight_tensor = _noop_cat(unfused_weights) if self.use_bias: bias_tensor = _noop_cat( @@ -956,21 +961,11 @@ def forward( # Initialize FP8 weights if needed weight_fp8 = None if self.fp8: - with_transpose = torch.is_grad_enabled() - if ( - not with_transpose - and is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() - ): - with_transpose = True if isinstance(weight_tensor, Float8Tensor): - # Fill transpose cache in FP8 tensor if needed - update_transpose_cache = with_transpose - if update_transpose_cache: - update_transpose_cache = ( - is_first_microbatch or skip_fp8_weight_update is not None - ) - if update_transpose_cache: + # Make sure transpose cache is valid, if present + # Note: Transpose cache may have been invalidated + # externally, e.g. by optimizer. + if weight_tensor._transpose is not None: weight_tensor.transpose_2d( fill_cache=True, noop_flag=skip_fp8_weight_update, @@ -985,7 +980,6 @@ def forward( cache_name=(None if is_first_microbatch is None else "weight"), update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, - with_transpose=with_transpose, fsdp_group=self.fsdp_group, ) @@ -1019,7 +1013,7 @@ def forward( self.ub_overlap_rs, self.ub_overlap_ag, self.ub_name, - is_first_module_in_mha, + fp8_output, self.fsdp_group, ) out = linear_fn(*args) diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index 969a468426..d5dc400206 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -120,7 +120,7 @@ class RMSNorm(torch.nn.Module): .. math:: y = \frac{x}{RMS_\varepsilon(x)} * (1 + \gamma) device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. 
It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. """ diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 77efef4ab6..12270d8340 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -9,54 +9,12 @@ import torch -from transformer_engine.pytorch.float8_tensor import Float8Tensor - - -def canonicalize_device(device: Optional[torch.device | str]) -> torch.device: - """Canonicalize PyTorch device - - If `None`, then returns the default CUDA device. - - """ - if device is None: - # Use default CUDA device - device = torch.get_default_device() - if device.type != "cuda": - device = torch.device("cuda", torch.cuda.current_device()) - elif not isinstance(device, torch.device): - device = torch.device(device) - if device.type == "cuda" and device.index is None: - device = torch.device("cuda", torch.cuda.current_device()) - return device - - -def canonicalize_dtype(dtype: Optional[torch.dtype]) -> torch.dtype: - """Canonicalize PyTorch datatype - - If `None`, then returns the default PyTorch datatype. - - """ - if dtype is None: - # Use default dtype - dtype = torch.get_default_dtype() - return dtype - - -def devices_match(device1: torch.device, device2: torch.device) -> bool: - """Whether two devices are the same""" - device1 = torch.device(device1) - device2 = torch.device(device2) - if device1.type != device2.type: - return False - if device1.type == "cuda": - index1 = device1.index - index2 = device2.index - if index1 is None: - index1 = torch.cuda.current_device() - if index2 is None: - index2 = torch.cuda.current_device() - return index1 == index2 - return device1 == device2 +from ..tensor import Float8Tensor +from ..utils import ( + canonicalize_device, # pylint: disable=unused-import + canonicalize_dtype, # pylint: disable=unused-import + devices_match, # pylint: disable=unused-import +) def is_float8_tensor(tensor: Any) -> bool: @@ -92,7 +50,13 @@ def convert_tensor( # Convert FP8 tensor if is_float8_tensor(tensor): - data = tensor._data.to(device=device, memory_format=memory_format) + data = tensor._data + if not devices_match(device, data.device): + data = data.to(device=device) + if memory_format != torch.preserve_format and not data.is_contiguous( + memory_format=memory_format + ): + data = data.contiguous(memory_format=memory_format) return Float8Tensor.make_like( tensor, data=data, diff --git a/transformer_engine/pytorch/ops/basic/all_reduce.py b/transformer_engine/pytorch/ops/basic/all_reduce.py index 622346b1c5..f466ade3a3 100644 --- a/transformer_engine/pytorch/ops/basic/all_reduce.py +++ b/transformer_engine/pytorch/ops/basic/all_reduce.py @@ -9,11 +9,8 @@ import torch -from transformer_engine.pytorch.ops.op import ( - BasicOperation, - OperationContext, -) -from .._common import is_float8_tensor +from ...tensor import QuantizedTensor +from ..op import BasicOperation, OperationContext class AllReduce(BasicOperation): @@ -54,8 +51,8 @@ def op_forward( # Perform all-reduce x = input_ - if is_float8_tensor(x): - x = x.from_float8() + if isinstance(x, QuantizedTensor): + x = x.dequantize() x = x.contiguous() torch.distributed.all_reduce(x, group=self.process_group) return x diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 826807d1c0..ce72dd8a55 100644 --- 
a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -289,10 +289,18 @@ def reset_parameters(self) -> None: # Cast to FP8 if needed if self._with_fp8_parameters: + dummy_amax = torch.empty( + (1, 1), + dtype=torch.float32, + device=self.device, + ) # Dummy buffer to avoid overwriting amax history weight = Float8Tensor.to_float8( weight, fp8_meta=self.get_fp8_meta("param"), + fp8_meta_forward=True, fp8_meta_index=0, + amax=dummy_amax, + with_transpose_cache=torch.is_grad_enabled(), ) # Save updated parameter @@ -467,25 +475,19 @@ def _functional_forward( input_fp8_meta["recipe"], fprop_tensor=True, ) - x_fp8 = Float8Tensor( - data=torch.empty_like(x_local, dtype=torch.uint8), + with_transpose_cache = weight.requires_grad + if tensor_parallel_mode == "column" and sequence_parallel: + with_transpose_cache = False + x_local = Float8Tensor.to_float8( + x_local, fp8_meta=input_fp8_meta, fp8_meta_forward=True, fp8_meta_index=0, fp8_dtype=fp8_dtype, - fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), - dtype=dtype, + with_transpose_cache=with_transpose_cache, ) - with_cast_transpose = weight.requires_grad - if tensor_parallel_mode == "column" and sequence_parallel: - with_cast_transpose = False - if with_cast_transpose: - x_fp8.cast_transpose_(x_local) - else: - x_fp8.copy_(x_local) - x_local = x_fp8 elif not with_fp8_compute and is_float8_tensor(x_local): - x_local = x_local.from_float8() + x_local = x_local.dequantize() x = x_local x_async = None if tensor_parallel_mode == "column" and sequence_parallel: @@ -510,11 +512,12 @@ def _functional_forward( w = Float8Tensor.to_float8( w, fp8_meta=weight_fp8_meta, + fp8_meta_forward=True, fp8_meta_index=0, fp8_dtype=fp8_dtype, ) elif not with_fp8_compute and is_float8_tensor(w): - w = w.from_float8() + w = w.dequantize() # Check bias tensor b = None @@ -815,25 +818,19 @@ def _functional_backward( grad_output_fp8_meta["recipe"], fprop_tensor=False, ) - dy_fp8 = Float8Tensor( - data=torch.empty_like(dy, dtype=torch.uint8), + with_transpose_cache = weight_requires_grad + if tensor_parallel_mode == "row" and sequence_parallel: + with_transpose_cache = False + dy = Float8Tensor.to_float8( + dy, fp8_meta=grad_output_fp8_meta, fp8_meta_forward=False, fp8_meta_index=0, fp8_dtype=fp8_dtype, - fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), - dtype=dtype, + with_transpose_cache=with_transpose_cache, ) - with_cast_transpose = weight_requires_grad - if tensor_parallel_mode == "row" and sequence_parallel: - with_cast_transpose = False - if with_cast_transpose: - dy_fp8.cast_transpose_(dy) - else: - dy_fp8.copy_(dy) - dy = dy_fp8 elif not with_fp8_compute and is_float8_tensor(dy): - dy = dy.from_float8() + dy = dy.dequantize() if tensor_parallel_mode == "row" and sequence_parallel: dy, dy_async = gather_along_first_dim( dy, @@ -853,26 +850,24 @@ def _functional_backward( device=device, dtype=dtype, ) + x_is_sharded = tensor_parallel_mode == "column" and sequence_parallel if with_fp8_compute and not is_float8_tensor(x_local): fp8_dtype = get_fp8_te_dtype( input_fp8_meta["recipe"], fprop_tensor=True, ) - x_fp8 = Float8Tensor( - data=torch.empty_like(x_local, dtype=torch.uint8), + x_local = Float8Tensor.to_float8( + x_local, fp8_meta=input_fp8_meta, fp8_meta_forward=True, fp8_meta_index=0, fp8_dtype=fp8_dtype, - fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), - dtype=dtype, + with_transpose_cache=(not x_is_sharded), ) - x_fp8.cast_transpose_(x_local) 
- x_local = x_fp8 elif not with_fp8_compute and is_float8_tensor(x_local): x_local = x_local.from_float8() x = x_local - if tensor_parallel_mode == "column" and sequence_parallel: + if x_is_sharded: x, x_async = gather_along_first_dim( x_local, tensor_parallel_group, @@ -898,19 +893,16 @@ def _functional_backward( weight_fp8_meta["recipe"], fprop_tensor=True, ) - w_fp8 = Float8Tensor( - data=torch.empty_like(w, dtype=torch.uint8), + w = Float8Tensor.to_float8( + w, fp8_meta=weight_fp8_meta, fp8_meta_forward=True, fp8_meta_index=0, fp8_dtype=fp8_dtype, - fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), - dtype=dtype, + with_transpose_cache=True, ) - w_fp8.cast_transpose_(w) - w = w_fp8 elif not with_fp8_compute and is_float8_tensor(w): - w = w.from_float8() + w = w.dequantize() # Construct grad input tensor if grad_input is not None: diff --git a/transformer_engine/pytorch/ops/basic/reduce_scatter.py b/transformer_engine/pytorch/ops/basic/reduce_scatter.py index 996ca2da31..c78dbc2877 100644 --- a/transformer_engine/pytorch/ops/basic/reduce_scatter.py +++ b/transformer_engine/pytorch/ops/basic/reduce_scatter.py @@ -9,12 +9,9 @@ import torch -from transformer_engine.pytorch.float8_tensor import Float8Tensor -from transformer_engine.pytorch.ops.op import ( - BasicOperation, - OperationContext, -) -from .._common import convert_tensor, is_float8_tensor +from ...tensor import Float8Tensor, QuantizedTensor +from ..op import BasicOperation, OperationContext +from .._common import convert_tensor class ReduceScatter(BasicOperation): @@ -63,8 +60,8 @@ def op_forward( # Check input tensor x = input_ - if is_float8_tensor(x): - x = x.from_float8() + if isinstance(x, QuantizedTensor): + x = x.dequantize() x = x.contiguous() # Perform reduce-scatter @@ -96,7 +93,7 @@ def op_backward( # Perform all-gather dy = convert_tensor(grad_output, memory_format=torch.contiguous_format) dx = None - if is_float8_tensor(dy): + if isinstance(dy, Float8Tensor): dx = Float8Tensor.make_like( dy, data=torch.empty( @@ -111,6 +108,8 @@ def op_backward( group=self.process_group, ) else: + if isinstance(dy, QuantizedTensor): + dy = dy.dequantize() dx = torch.empty(input_dims, dtype=dy.dtype, device=dy.device) torch.distributed.all_gather_into_tensor( dx, diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 322b93a1d8..da0ba3328a 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -219,10 +219,8 @@ def step(self, closure=None, grad_scaler=None): assert ( state["master_param"].shape == p.shape ), "Master weights shape must match model weights shape" - else: - state["master_param"] = None - p_master = state["master_param"] + p_master = state.get("master_param", None) p_grad = p.grad if self.master_weights and p_master is not None and p_master.grad is not None: diff --git a/transformer_engine/pytorch/optimizers/multi_tensor_apply.py b/transformer_engine/pytorch/optimizers/multi_tensor_apply.py index b8d6d1f263..191b57eab9 100644 --- a/transformer_engine/pytorch/optimizers/multi_tensor_apply.py +++ b/transformer_engine/pytorch/optimizers/multi_tensor_apply.py @@ -3,6 +3,7 @@ # See LICENSE for license information. 
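The communication ops above (`AllReduce`, `ReduceScatter`) now branch on the generic `QuantizedTensor` interface instead of `Float8Tensor` specifically. A minimal sketch of the dequantize-before-collective pattern they share; the helper name `_to_reducible` is hypothetical and not part of the diff:

    import torch
    from transformer_engine.pytorch.tensor import QuantizedTensor

    def _to_reducible(x: torch.Tensor) -> torch.Tensor:
        # FP8 values cannot be summed directly across ranks, so quantized
        # inputs are dequantized to full precision before the collective.
        if isinstance(x, QuantizedTensor):
            x = x.dequantize()
        return x.contiguous()
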
"""Multi-tensor apply entry.""" +from torch.distributed._tensor import DTensor class MultiTensorApply: # pylint: disable=too-few-public-methods @@ -12,6 +13,11 @@ def __init__(self, chunk_size): self.chunk_size = chunk_size def __call__(self, op, noop_flag_buffer, tensor_lists, *args): + for i, ts in enumerate(tensor_lists): + for j, t in enumerate(ts): + if isinstance(t, DTensor): + tensor_lists[i][j] = t._local_tensor + return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args) diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py new file mode 100644 index 0000000000..9987db58e0 --- /dev/null +++ b/transformer_engine/pytorch/permutation.py @@ -0,0 +1,270 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Linear API""" +import warnings +from typing import Tuple +import torch + +import transformer_engine_torch as tex +from .constants import TE_DType +from .float8_tensor import Float8Tensor + + +__all__ = [ + "moe_permute", + "moe_unpermute", +] + + +class _moe_permute(torch.autograd.Function): + """functional Permute""" + + workspace = None + max_expanded_token_num = 0 + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + indices: torch.Tensor, + num_out_tokens: int, + max_token_num: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Empty input check + if not inp.numel(): + return inp, torch.tensor([], device=inp.device) + + # Device check + assert inp.is_cuda, "TransformerEngine needs CUDA." + assert indices.is_cuda, "TransformerEngine needs CUDA." + # Shape check + assert inp.size(0) == indices.size(0), "Permute not possible" + + # Data type check + fp8 = isinstance(inp, Float8Tensor) + if fp8: + dtype = inp._fp8_dtype + fp8_scale_inv = inp._scale_inv + inp = inp._data + else: + dtype = TE_DType[inp.dtype] + if indices.dtype != torch.int32: + warnings.warn( + f"The data type of the input `indices` of Permute is {indices.dtype}! " + "The recommended type is torch.int32." + ) + indices = indices.to(torch.int32) + + topK = indices.size(1) + + input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK + if _moe_permute.max_expanded_token_num < input_max_expanded_token_num: + _moe_permute.max_expanded_token_num = input_max_expanded_token_num + _moe_permute.workspace = [] + + permuted_act, row_id_map, _moe_permute.workspace = tex.moe_permute_fwd( + inp, + dtype, + indices, + num_out_tokens, + _moe_permute.workspace, + _moe_permute.max_expanded_token_num, + ) + + if fp8: + permuted_act = Float8Tensor( + data=permuted_act, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv + ) + + ctx.row_id_map = row_id_map + ctx.num_tokens = indices.size(0) + ctx.topK = indices.size(1) + ctx.fp8 = fp8 + return permuted_act, row_id_map + + @staticmethod + def backward( + ctx, + permuted_act_grad: torch.Tensor, + _, + ) -> Tuple[torch.Tensor, ...]: + # Empty input check + if not permuted_act_grad.numel(): + return permuted_act_grad, None, None, None + + if not permuted_act_grad.is_contiguous(): + permuted_act_grad = permuted_act_grad.contiguous() + + if ctx.fp8: + assert isinstance( + permuted_act_grad, Float8Tensor + ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute." 
+ dtype = permuted_act_grad._fp8_dtype + fp8_scale_inv = permuted_act_grad._scale_inv + permuted_act_grad = permuted_act_grad._data + else: + dtype = TE_DType[permuted_act_grad.dtype] + + act_grad = None + if ctx.needs_input_grad[0]: + act_grad = tex.moe_permute_bwd( + permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK + ) + if ctx.fp8: + act_grad = Float8Tensor( + data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv * ctx.topK + ) + + return act_grad, None, None, None + + +class _moe_unpermute(torch.autograd.Function): + """functional Unpermute""" + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + row_id_map: torch.Tensor, + probs: torch.Tensor, + ) -> torch.Tensor: + # Empty input check + if not inp.numel(): + ctx.probs = probs + return inp + + # None probs check + if probs is not None: + assert probs.is_cuda, "TransformerEngine needs CUDA." + + if probs.dtype != torch.float32: + warnings.warn( + f"The data type of the input `probs` of Unpermute is {probs.dtype}! " + "The recommended type is torch.float32." + ) + probs = probs.to(torch.float32) + + num_tokens = probs.size(0) + topK = probs.size(1) + else: + num_tokens = row_id_map.size(0) + topK = 1 + probs = torch.empty(0) + + # Device check + assert inp.is_cuda, "TransformerEngine needs CUDA." + assert row_id_map.is_cuda, "TransformerEngine needs CUDA." + + # Data type check + fp8 = isinstance(inp, Float8Tensor) + if fp8: + dtype = inp._fp8_dtype + fp8_scale_inv = inp._scale_inv + inp = inp._data + else: + dtype = TE_DType[inp.dtype] + if row_id_map.dtype != torch.int32: + warnings.warn( + f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! " + "The recommended type is torch.int32." + ) + row_id_map = row_id_map.to(torch.int32) + + unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK) + + if fp8: + unpermuted_output = Float8Tensor( + data=unpermuted_output, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv + ) + + ctx.save_for_backward(inp, row_id_map, probs) + ctx.fp8 = fp8 + return unpermuted_output + + @staticmethod + def backward( + ctx, + unpermuted_act_grad: torch.Tensor, + ) -> Tuple[torch.Tensor, None, torch.Tensor]: + # Empty input check + if not unpermuted_act_grad.numel(): + return unpermuted_act_grad, None, ctx.probs + + if not unpermuted_act_grad.is_contiguous(): + unpermuted_act_grad = unpermuted_act_grad.contiguous() + + if ctx.fp8: + assert isinstance( + unpermuted_act_grad, Float8Tensor + ), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute." + dtype = unpermuted_act_grad._fp8_dtype + fp8_scale_inv = unpermuted_act_grad._scale_inv + unpermuted_act_grad = unpermuted_act_grad._data + else: + dtype = TE_DType[unpermuted_act_grad.dtype] + + inp, row_id_map, probs = ctx.saved_tensors + + act_grad = None + if ctx.needs_input_grad[0]: + act_grad, prob_grad = tex.moe_unpermute_bwd( + unpermuted_act_grad, inp, dtype, row_id_map, probs + ) + if ctx.fp8: + act_grad = Float8Tensor(data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv) + if not ctx.needs_input_grad[2]: + prob_grad = None + + return act_grad, None, prob_grad + + +def moe_permute( + inp: torch.Tensor, + indices: torch.Tensor, + num_out_tokens: int = -1, + max_token_num: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Permute the tokens based on the indices. Token with the same index will be grouped together. 
+ + Parameters + ---------- + inp: torch.Tensor + Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. + indices: torch.Tensor + The token to expert indices tensor of shape [num_tokens, topK] and dtype 'int32'. + num_out_tokens: int, default = -1 + The effective output token count, representing the number of tokens not dropped. + By default, set to '-1', meaning no tokens are dropped. + max_token_num: int, default = -1 + The maximum number of tokens, used for workspace allocation. + By default, set to '-1', meaning the calculation of the size of workspace is + automatically taken over by the operator. + """ + return _moe_permute.apply(inp, indices, num_out_tokens, max_token_num) + + +def moe_unpermute( + inp: torch.Tensor, + row_id_map: torch.Tensor, + probs: torch.Tensor = None, +) -> torch.Tensor: + """ + Unpermute a tensor with permuted tokens, and optionally merge the tokens with their + corresponding probabilities. + + Parameters + ---------- + inp: torch.Tensor + Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted. + row_id_map: torch.Tensor + The tensor of a mapping table for sorted indices used to unpermute the tokens, + which is the second output tensor of `Permute`. + probs: torch.Tensor + The tensor of probabilities corresponding to the permuted tokens. If provided, + the unpermuted tokens will be merged with their respective probabilities. + By default, set to an empty tensor, which means that the tokens are directly merged by accumulation. + """ + return _moe_unpermute.apply(inp, row_id_map, probs) diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index e2f15d5d89..034e671150 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -56,7 +56,7 @@ description="Transformer acceleration library - Torch Lib", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, - install_requires=["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"], + install_requires=["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"], tests_require=["numpy", "onnxruntime", "torchvision"], ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): diff --git a/transformer_engine/pytorch/softmax.py b/transformer_engine/pytorch/softmax.py index 4fb8a28857..a632851a76 100644 --- a/transformer_engine/pytorch/softmax.py +++ b/transformer_engine/pytorch/softmax.py @@ -23,9 +23,6 @@ def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor: """Return the causal upper triangular mask for softmax input""" - if sq == 1: - return torch.zeros((1, sk), dtype=torch.bool, device="cuda") - matrix_identifiers = (mask_type, sq, sk) if matrix_identifiers not in _default_causal_mask: diagonal_offset = sk - sq + 1 if "bottom_right" in mask_type else 1 diff --git a/transformer_engine/pytorch/te_onnx_extensions.py b/transformer_engine/pytorch/te_onnx_extensions.py index 05c1a5a0f5..0fa9401163 100755 --- a/transformer_engine/pytorch/te_onnx_extensions.py +++ b/transformer_engine/pytorch/te_onnx_extensions.py @@ -74,7 +74,7 @@ def is_dtype_bf16(t): return t.type().scalarType() == "BFloat16" -def quantize(g, inputs, scale_inv, fp8_tensor): +def quantize(g, inputs, scale, fp8_tensor): """Helper Function for Quantization""" output_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs) @@ -83,7 +83,7 @@ def quantize(g, inputs, scale_inv, fp8_tensor): if not is_dtype_fp32(inputs): inputs = g.op("Cast", inputs, 
to_i=_C_onnx.TensorProtoDataType.FLOAT) - scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor])) + scale = g.op("Constant", value_t=torch.tensor(1 / scale[fp8_tensor])) q_op = g.op(make_op_name("TRT_FP8QuantizeLinear"), inputs, scale).setType( inputs.type().with_dtype(torch.uint8).with_sizes(output_shape) ) @@ -124,18 +124,18 @@ def compute_in_fp32(g, inp, subgraph, *args, **kwargs): return sg_out -@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") +@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") def onnx_cast_to_fp8(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for cast_to_fp8""" # pylint: disable=unused-argument - return quantize(g, inputs, scale_inv, fp8_tensor) + return quantize(g, inputs, scale, fp8_tensor) -@symbolic_helper.parse_args("v", "v", "v", "v", "fs", "i", "i") +@symbolic_helper.parse_args("v", "fs", "v", "v", "v", "i", "i") def onnx_cast_to_fp8_noalloc(g, inputs, scale, output, amax, scale_inv, fp8_tensor, otype): """ONNX graph for cast_to_fp8_noalloc""" # pylint: disable=unused-argument - return quantize(g, inputs, scale_inv, fp8_tensor) + return quantize(g, inputs, scale, fp8_tensor) @symbolic_helper.parse_args("v", "fs", "i", "i", "i") @@ -145,25 +145,25 @@ def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype): return dequantize(g, inputs, scale_inv, fp8_tensor, otype) -@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") +@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_gelu""" # pylint: disable=unused-argument # TE computes GELU using float32 precision so wrap the GELU subgraph with # conversion to/from float32. gelu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.gelu, "tanh") - if scale_inv: - gelu = quantize(g, gelu, scale_inv, fp8_tensor) + if scale: + gelu = quantize(g, gelu, scale, fp8_tensor) return gelu -@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") +@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") def onnx_fp8_relu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_relu""" # pylint: disable=unused-argument relu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.relu) - if scale_inv: - relu = quantize(g, relu, scale_inv, fp8_tensor) + if scale: + relu = quantize(g, relu, scale, fp8_tensor) return relu @@ -178,13 +178,13 @@ def onnx_swiglu(g: jit_utils.GraphContext, inp, dim): return g.op("Mul", g.op("Sigmoid", first), second) -@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") +@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") def onnx_fp8_swiglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_swiglu""" # pylint: disable=unused-argument swiglu = compute_in_fp32(g, inputs, onnx_swiglu, 1) - if scale_inv: - swiglu = quantize(g, swiglu, scale_inv, fp8_tensor) + if scale: + swiglu = quantize(g, swiglu, scale, fp8_tensor) return swiglu @@ -199,13 +199,13 @@ def onnx_reglu(g: jit_utils.GraphContext, inp, dim): return g.op("Mul", g.op("Relu", first), second) -@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") +@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") def onnx_fp8_reglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_reglu""" # pylint: disable=unused-argument reglu = compute_in_fp32(g, inputs, onnx_reglu, 1) - if scale_inv: - reglu = quantize(g, reglu, scale_inv, fp8_tensor) + if scale: + reglu = quantize(g, reglu, scale, 
fp8_tensor) return reglu @@ -221,13 +221,13 @@ def onnx_geglu(g: jit_utils.GraphContext, inp, dim): return g.op("Mul", first_gelu, second) -@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") +@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") def onnx_fp8_geglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_geglu""" # pylint: disable=unused-argument geglu = compute_in_fp32(g, inputs, onnx_geglu, 1) - if scale_inv: - geglu = quantize(g, geglu, scale_inv, fp8_tensor) + if scale: + geglu = quantize(g, geglu, scale, fp8_tensor) return geglu @@ -245,7 +245,7 @@ def onnx_fp8_geglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): "v", "fs", "i", - "fs", + "v", "v", "i", "v", @@ -330,7 +330,7 @@ def _ones_like(g, inp, dtype): return one -@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i", "i", "b") +@symbolic_helper.parse_args("v", "v", "v", "f", "fs", "v", "v", "i", "i", "i", "b") def onnx_layernorm_fwd_fp8( g, inputs, @@ -355,7 +355,7 @@ def onnx_layernorm_fwd_fp8( bias = g.op("Cast", bias, to_i=inp_dtype) ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_gamma) - fp8_ln = quantize(g, ln, scale_inv, fp8_tensor) + fp8_ln = quantize(g, ln, scale, fp8_tensor) return fp8_ln @@ -391,7 +391,7 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_ga return ln -@symbolic_helper.parse_args("v", "v", "f", "v", "v", "fs", "i", "i", "i", "b") +@symbolic_helper.parse_args("v", "v", "f", "fs", "v", "v", "i", "i", "i", "b") def onnx_rmsnorm_fwd_fp8( g, inputs, @@ -413,7 +413,7 @@ def onnx_rmsnorm_fwd_fp8( weight = g.op("Cast", weight, to_i=inp_dtype) ln = onnx_rmsnorm_fwd(g, inputs, weight, eps, sm_margin, zero_centered_gamma) - fp8_ln = quantize(g, ln, scale_inv, fp8_tensor) + fp8_ln = quantize(g, ln, scale, fp8_tensor) return fp8_ln diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py new file mode 100644 index 0000000000..2bad862768 --- /dev/null +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Custom tensor classes""" + +from .float8_tensor import Float8Tensor +from .quantized_tensor import QuantizedTensor diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py new file mode 100644 index 0000000000..610523a10d --- /dev/null +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -0,0 +1,972 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tensor class with FP8 data""" +from __future__ import annotations +from typing import Any, Dict, Optional, Tuple +import warnings + +import torch +import transformer_engine_torch as tex + +from transformer_engine_torch import DType as TE_DType +from ..constants import TE_DType as torch_to_transformer_engine_dtype +from ..cpp_extensions import ( + cast_from_fp8, + cast_to_fp8, + fp8_cast_transpose_fused, +) +from ..fp8 import FP8GlobalStateManager +from ..utils import devices_match +from .quantized_tensor import QuantizedTensor + +aten = torch.ops.aten +updated_fp8_params = {} + + +def _make_fp8_attr_property_funcs(name: str) -> Any: + """Make accessors for an FP8 attribute + + We store FP8 attributes in a dictionary so we can share them + between tensors with the same data, e.g. detached tensors. 
For + convenience, we also expose them as property attributes. This + function creates the accessors for property attributes. + + Parameters + ---------- + name: str + Key in dictionary of FP8 attributes + + """ + + def get_func(self) -> Any: + return self._fp8_attrs[name] + + def set_func(self, value: Any) -> None: + self._fp8_attrs[name] = value + + def del_func(self) -> None: + del self._fp8_attrs[name] + + return dict(fget=get_func, fset=set_func, fdel=del_func) + + +class _FromFloat8Func(torch.autograd.Function): + """Cast from FP8 to other dtype""" + + @staticmethod + def forward( + _ctx: torch.autograd.function.FunctionCtx, # unused + tensor: Float8Tensor, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + return tensor.dequantize(dtype=dtype) + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # Assume that we want gradients in full precision + return grad, None + + +def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None: + """Amax scale and update when there is at least 1 trainable FP8 parameter.""" + param_id = id(param._data) + + if param_id not in FP8GlobalStateManager.fp8_param_to_autocast: + return + + autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id] + + if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params: + return + + if autocast_key in updated_fp8_params: + updated_fp8_params[autocast_key].add(param_id) + else: + updated_fp8_params[autocast_key] = {param_id} + + current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key] + # All FP8 trainable parameters have been updated. + if updated_fp8_params[autocast_key] == current_fp8_params_set: + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True, fp8_weights=True) + del updated_fp8_params[autocast_key] + + +class _ToFloat8Func(torch.autograd.Function): + """Cast to FP8 from other dtype""" + + @staticmethod + def forward( + _ctx: torch.autograd.function.FunctionCtx, # unused + tensor: torch.Tensor, + fp8_meta: Optional[Dict[str, Any]] = None, + fp8_meta_forward: bool = True, + fp8_meta_index: Optional[int] = None, + fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + with_transpose_cache: bool = False, + ) -> Float8Tensor: + + # Tensor attributes + dtype = tensor.dtype + if dtype not in (torch.float32, torch.bfloat16, torch.float16): + dtype = torch.float32 + device = tensor.device + if device.type != "cuda": + device = torch.device("cuda") + + # FP8 data buffer + data = torch.empty(tensor.size(), dtype=torch.uint8, device=device) + + # Check scale + if scale is None and fp8_meta is None: + scale = 1 + if scale is not None: + if isinstance(scale, torch.Tensor): + scale = scale.to(device=device, dtype=torch.float32) + else: + scale = torch.full([1], scale, dtype=torch.float32, device=device) + + # Check scale-inverse + if scale_inv is None: + scale_inv = torch.empty([1], dtype=torch.float32, device=device) + elif not devices_match(scale_inv.device, device) or scale_inv.dtype != dtype: + scale_inv = scale_inv.to(device=device, dtype=torch.float32) + + # Transpose cache + data_transpose = None + if with_transpose_cache: + data_transpose = torch.empty( + (data.size(-1), data.numel() // data.size(-1)), + dtype=torch.uint8, + device=tensor.device, + ) + + # Construct FP8 tensor + out = Float8Tensor( + data=data, + 
fp8_meta=fp8_meta, + fp8_meta_forward=fp8_meta_forward, + fp8_meta_index=fp8_meta_index, + fp8_dtype=fp8_dtype, + fp8_scale_inv=scale_inv, + dtype=dtype, + data_transpose=data_transpose, + ) + + # Cast to FP8 tensor + out.quantize_(tensor, scale=scale, amax=amax) + + return out + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # Assume that we want gradients in full precision + return grad, None, None, None, None, None, None, None + + +class _IdentityFunc(torch.autograd.Function): + """Identity function + + If constructor keyword-arguments are provided, then construct a + new Float8Tensor using the provided tensor's attributes. + + """ + + @staticmethod + def forward( + ctx, + tensor: Float8Tensor, + init_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + + # Return input tensor if constructor kwargs are not provided + ctx.input_dtype = tensor.dtype + if init_kwargs is None: + return tensor + + # Construct new tensor if constructor kwargs are provided + default_kwargs = dict( + data=tensor._data, + fp8_meta=tensor._fp8_meta, + fp8_meta_forward=tensor._fp8_meta_forward, + fp8_meta_index=tensor._fp8_meta_index, + fp8_dtype=tensor._fp8_dtype, + fp8_scale_inv=tensor._scale_inv, + dtype=tensor.dtype, + ) + for key, val in default_kwargs.items(): + if key not in init_kwargs: + init_kwargs[key] = val + return Float8Tensor(**init_kwargs) + + @staticmethod + def backward(ctx, grad): + return grad.to(ctx.input_dtype), None + + +class _ViewFunc(torch.autograd.Function): + """View function + + View the Float8Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: torch.Tensor, + shape: Tuple[int] = None, + ) -> torch.Tensor: + + # Return input tensor if shape is not provided + ctx.shape = tensor.shape + if shape is None: + return tensor + + # Construct new tensor if shape is provided + if isinstance(tensor, Float8Tensor): + return Float8Tensor.make_like( + tensor, + data=tensor._data.view(*shape), + ) + return tensor.view(*shape) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + + if isinstance(grad, Float8Tensor): + dgrad = Float8Tensor.make_like( + grad, + data=grad._data.view(ctx.shape), + ) + return dgrad, None + return grad.view(ctx.shape), None + + +class _ReshapeFunc(torch.autograd.Function): + """Reshape function + + Reshape the Float8Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: torch.Tensor, + shape: Tuple[int] = None, + ) -> torch.Tensor: + + # Return input tensor if shape is not provided + ctx.shape = tensor.shape + if shape is None: + return tensor + + # Construct new tensor if shape is provided + if isinstance(tensor, Float8Tensor): + return Float8Tensor.make_like( + tensor, + data=tensor._data.reshape(*shape), + ) + return tensor.reshape(*shape) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + + if isinstance(grad, Float8Tensor): + dgrad = Float8Tensor.make_like( + grad, + data=grad._data.reshape(ctx.shape), + ) + return dgrad, None + return grad.reshape(ctx.shape), None + + +class Float8Tensor(QuantizedTensor): + """Experimental tensor class with FP8 data + + The tensor presents as having a standard, higher-precision dtype, + but the data itself is (scaled) FP8. For most tensor operations, + the data will be cast to the nominal dtype before performing the + operation. 
+ + Parameters + ---------- + data: torch.Tensor + Raw FP8 data in a uint8 tensor + fp8_attrs: dict, optional + FP8 metadata, primarily managed by Float8Tensor. If + provided, all other FP8 configuration is ignored. + fp8_meta: dict, optional + FP8 metadata object, primarily managed by TE modules. + fp8_meta_forward: bool, default = `True` + Whether to access the FP8 metadata for the + forward pass. Ignored if fp8_meta is not + provided. + fp8_meta_index: int, optional + Index to access in FP8 meta tensors. Required if + fp8_meta is provided and otherwise ignored. + fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3 + FP8 format. + fp8_scale_inv: torch.Tensor + Reciprocal of the scaling factor applied when + casting to FP8, i.e. the scaling factor that must + be applied when casting from FP8 to higher + precision. Can be inferred from fp8_meta if + provided. + dtype: torch.dtype, default = torch.float32 + Nominal tensor datatype. + + """ + + def __new__( + cls, + *, + data: torch.Tensor, + fp8_attrs: Optional[Dict[str, Any]] = None, + fp8_meta: Optional[Dict[str, Any]] = None, + fp8_meta_forward: bool = True, + fp8_meta_index: Optional[int] = None, + fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, + fp8_scale_inv: Optional[torch.Tensor] = None, + dtype: torch.dtype = torch.float32, + data_transpose: Optional[torch.Tensor] = None, + ): + + # Check that data buffer is valid + if data.element_size() != 1: + raise ValueError( + f"Float8Tensor requires data buffer with 8-bit dtype (got dtype={data.dtype})" + ) + if data.requires_grad: + raise ValueError("Float8Tensor requires non-differentiable data buffer") + if not data.is_cuda: + data = data.cuda() + + # Initialize tensor object + self = torch.Tensor._make_wrapper_subclass( + cls, + data.size(), + strides=data.stride(), + storage_offset=data.storage_offset(), + dtype=dtype, + layout=data.layout, + requires_grad=data.requires_grad, + device=data.device, + ) + self._data: torch.Tensor = data + + # Initialize dict of class attributes + # Note: We store FP8 attributes in a dictionary so we can + # share them between tensors with the same data, e.g. detached + # tensors. + self._fp8_attrs: dict + if fp8_attrs is None: + self._fp8_attrs = {} + else: + self._fp8_attrs = fp8_attrs + return self + + # FP8 meta tensors + if fp8_meta is not None and fp8_meta_index is None: + raise ValueError( + "To initialize Float8Tensor with FP8 meta tensors, " + "the FP8 meta tensor index must also be provided" + ) + self._fp8_meta: Optional[Dict[str, Any]] = fp8_meta + self._fp8_meta_forward: bool = fp8_meta_forward + self._fp8_meta_index: Optional[int] = fp8_meta_index + + # FP8 dtype + assert fp8_dtype in ( + TE_DType.kFloat8E4M3, + TE_DType.kFloat8E5M2, + ), f"Unsupported fp8_dtype {fp8_dtype}." 
+ self._fp8_dtype: TE_DType = fp8_dtype + + # FP8 scale-inverse + if fp8_scale_inv is None and self._fp8_meta is not None: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=self._fp8_meta_forward, + ) + fp8_scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] + fp8_scale_inv = fp8_scale_inv.detach().view(1).clone() + if fp8_scale_inv is None: + raise ValueError( + "Attempted to initialize Float8Tensor without specifying scale-inverse" + ) + if not isinstance(fp8_scale_inv, torch.Tensor): + fp8_scale_inv = torch.full( + [1], + fp8_scale_inv, + dtype=torch.float32, + device=self._data.device, + ) + if fp8_scale_inv.numel() != 1: + raise ValueError( + "Attempted to initialize Float8Tensor with invalid scale-inverse tensor" + ) + if fp8_scale_inv.dim() != 1: + fp8_scale_inv = fp8_scale_inv.reshape(1) + if ( + not devices_match(fp8_scale_inv.device, self._data.device) + or fp8_scale_inv.dtype != torch.float32 + ): + fp8_scale_inv = fp8_scale_inv.to( + device=self._data.device, + dtype=torch.float32, + ) + self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv + + # FP8 transpose cache + self._transpose: Optional[Float8Tensor] = data_transpose + self._transpose_invalid: bool = self._transpose is None + + return self + + @classmethod + def make_like( + cls, + tensor: Float8Tensor, + *, + data: torch.Tensor, + fp8_attrs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Float8Tensor: + """Use attributes of a Float8Tensor to create another Float8Tensor + + See constructor for list of keyword arguments. + + """ + default_kwargs = dict( + fp8_meta=tensor._fp8_meta, + fp8_meta_forward=tensor._fp8_meta_forward, + fp8_meta_index=tensor._fp8_meta_index, + fp8_dtype=tensor._fp8_dtype, + fp8_scale_inv=tensor._scale_inv, + dtype=tensor.dtype, + ) + for key, val in default_kwargs.items(): + if key not in kwargs: + kwargs[key] = val + return Float8Tensor(data=data, fp8_attrs=fp8_attrs, **kwargs) + + def __repr__(self): + return ( + "Float8Tensor(" + f"fp8_dtype={self._fp8_dtype}, " + f"scale_inv={self._scale_inv.item()}, " + f"data={self.from_float8(dtype=self.dtype)}" + ")" + ) + + def dequantize(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + + # Convert PyTorch dtype to TE dtype + if dtype is None: + dtype = self.dtype + dtype = torch_to_transformer_engine_dtype[dtype] + + # Make sure FP8 data is in expected format + data = self._data + if data.device.type != "cuda": + data = data.cuda() + if not data.is_contiguous(): + data = data.contiguous() + if data.dim() != 2: + data = data.view(1, -1) + + # Cast from FP8 + out = cast_from_fp8( + data.view(1, -1), + None, # fp8_meta_tensor + None, # fp8_tensor + self._fp8_dtype, + dtype, + scale_inv=self._scale_inv, + ) + + # Make sure output is in expected format + if out.size() != self.size(): + out = out.view(self.size()) + return out + + def from_float8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from Float8Tensor + + By default the resulting tensor's dtype is the + Float8Tensor's nominal dtype. 
+ """ + return _FromFloat8Func.apply(self, dtype) + + def quantize_( + self, + tensor: torch.Tensor, + *, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + noop_flag: Optional[torch.Tensor] = None, + ) -> Float8Tensor: + """Update FP8 data + + Parameters + ---------- + tensor: torch.Tensor + Tensor to copy from + scale: torch.Tensor, optional + Scaling factor to use for FP8 quantization + amax: torch.Tensor, optional + History of maximum absolute values. The first entry will + be updated with the absmax of `tensor`. + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid performing update + + """ + src = tensor + dst = self + + # In-place operations invalidate transpose cache + self._reset_caches() + + # Special logic if other tensor is Float8Tensor + if isinstance(src, Float8Tensor): + + # Cast to plain tensor if FP8 dtypes don't match + if dst._fp8_dtype != src._fp8_dtype: + return dst.quantize_(src.dequantize()) + + # Directly copy FP8 data + dst._data.copy_(src._data.detach()) + dst._scale_inv.copy_(src._scale_inv.detach()) + if amax is not None or dst._fp8_meta is not None: + src_amax: torch.Tensor + if src._fp8_meta is None: + src_min, src_max = src.dequantize().aminmax() + src_amax = torch.maximum(-src_min, src_max) + else: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=src._fp8_meta_forward, + ) + fp8_meta_index = src._fp8_meta_index + src_amax = src._fp8_meta[fp8_meta_key].amax_history[0, fp8_meta_index] + dst_amax: torch.Tensor + if amax is None: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=dst._fp8_meta_forward, + ) + fp8_meta_index = dst._fp8_meta_index + dst_amax = dst._fp8_meta[fp8_meta_key].amax_history[0, fp8_meta_index] + else: + dst_amax = amax + if dst_amax.dim() > 0: + dst_amax = dst_amax[tuple([0] * dst_amax.dim())] + torch.maximum(src_amax, dst_amax, out=dst_amax) + if dst._transpose is not None: + if src._transpose is None: + dst.transpose_2d(force_compute=True, fill_cache=True) + else: + dst._transpose.copy_(src._transpose) + dst._transpose_invalid = False + return self + + # Convert QuantizedTensor to plain tensor + if isinstance(src, QuantizedTensor): + return dst.quantize_(src.dequantize()) + + # Make sure input is in expected format + if src.size() != dst.size(): + src = src.expand(dst.size()) + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if src.dtype not in (torch.float32, torch.bfloat16, torch.float16): + src = src.float() + if not src.is_contiguous(): + src = src.contiguous() + + # Make sure FP8 scaling factors are in expected format + if scale is not None: + if isinstance(scale, torch.Tensor): + if not devices_match(scale.device, dst.device) or scale.dtype != torch.float32: + scale = scale.to(device=dst.device, dtype=torch.float32) + else: + scale = torch.full([1], scale, dtype=torch.float32, device=dst.device) + if amax is not None: + while amax.dim() < 2: + amax = amax.unsqueeze(0) + if not devices_match(amax.device, dst.device): + raise ValueError( + f"Invalid device for amax (expected {dst.device}, found {amax.device})" + ) + if amax.dtype != torch.float32: + raise ValueError(f"Invalid dtype for amax (expected float32, found {amax.type})") + + # Default FP8 scaling factors + fp8_meta = None + if dst._fp8_meta is None: + if scale is None: + scale = dst._scale_inv.reciprocal() + if amax is None: + amax = torch.empty((1, 1), dtype=torch.float32, device=dst.device) + else: + fp8_meta_key = 
FP8GlobalStateManager.get_meta_tensor_key( + forward=dst._fp8_meta_forward, + ) + fp8_meta = dst._fp8_meta[fp8_meta_key] + + # Check local data + if not dst._data.is_contiguous(): + raise RuntimeError("Transformer Engine cast kernels require contiguous data") + + # Perform FP8 cast + if dst._transpose is None: + dst_data = dst._data + if src.dim() != 2: + src = src.view(1, -1) + dst_data = dst_data.view(1, -1) + cast_to_fp8( + src, + fp8_meta, + dst._fp8_meta_index, + dst._fp8_dtype, + out=dst_data, + scale=scale, + amax=amax, + scale_inv=dst._scale_inv, + ) + else: + fp8_cast_transpose_fused( + src.view(-1, src.size(-1)), + fp8_meta, + dst._fp8_meta_index, + dst._fp8_dtype, + cast_out=dst._data, + transpose_out=dst._transpose, + scale=scale, + amax=amax, + scale_inv=dst._scale_inv, + noop_flag=noop_flag, + ) + dst._transpose_invalid = False + + # Callback hook to perform amax reduction after optimizer step + post_optimizer_step_fwd_amax_reduction(self) + + return self + + @classmethod + def to_float8( + cls, + tensor: torch.Tensor, + *, + fp8_meta: Optional[Dict[str, Any]] = None, + fp8_meta_forward: bool = True, + fp8_meta_index: Optional[int] = None, + fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + with_transpose_cache: bool = False, + ): + """Construct Float8Tensor from plain PyTorch tensor""" + return _ToFloat8Func.apply( + tensor, + fp8_meta, + fp8_meta_forward, + fp8_meta_index, + fp8_dtype, + scale, + amax, + scale_inv, + with_transpose_cache, + ) + + def detach(self) -> Float8Tensor: + return Float8Tensor.make_like( + self, + data=self._data, + fp8_attrs=self._fp8_attrs, + ) + + def clone(self) -> Float8Tensor: + data = self._data.detach().clone() + data_transpose = None + if self._transpose is not None: + data_transpose = self._transpose.detach().clone() + return _IdentityFunc.apply( + self, + dict( + data=data, + data_transpose=data_transpose, + ), + ) + + def view(self, *shape: Tuple[int]) -> Float8Tensor: + return _ViewFunc.apply(self, shape) + + def reshape(self, *shape: Tuple[int]) -> Float8Tensor: + return _ReshapeFunc.apply(self, shape) + + def contiguous( + self, + *, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> Float8Tensor: + """Returns tensor with data in provided memory format + + Returns `self` if data is already in correct memory format. + + """ + if self._data.is_contiguous(memory_format=memory_format): + return self + return _IdentityFunc.apply( + self, + {"data": self._data.detach().contiguous(memory_format=memory_format)}, + ) + + def transpose_2d( + self, + *, + force_compute: bool = False, + fill_cache: bool = False, + noop_flag: Optional[torch.Tensor] = None, + cache: Optional[bool] = None, + ) -> torch.Tensor: + """ + 2D transpose with caching support. + + Parameters + ---------- + force_compute: bool, default = `False` + Force computation of transpose. Otherwise use + cached values, if possible. + fill_cache: bool, default = `False` + Cache output tensor for future function calls. + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid updating + cached values, if possible. 
+ cache: bool, deprecated + + """ + + # Handle deprecated cache kwarg + if cache is not None: + msg = ( + "cache kwarg for Float8Tensor.transpose_2d is deprecated, " + "please use force_compute and fill_cache instead" + ) + warnings.warn(msg, DeprecationWarning) + if cache: + force_compute = False + fill_cache = True + else: + force_compute = True + fill_cache = False + + # Need to compute transpose if cache is invalid + need_compute = force_compute + if self._transpose is None: + need_compute = True + elif self._transpose_invalid: + need_compute = True + + # Need to apply transpose kernel if noop flag is applied + if noop_flag is not None: + need_compute = True + + # Return cached transpose if possible + if not need_compute: + return self._transpose + + # Allocate output if needed + data = self._data.contiguous().reshape(-1, self.size(-1)) + out = self._transpose + if out is None: + out = torch.empty( + (data.size(1), data.size(0)), + dtype=torch.uint8, + device=data.device, + ) + noop_flag = None + else: + self._transpose_invalid = False + + # Apply transpose kernel + fp8_dtype = self._fp8_dtype + if noop_flag is None: + tex.fp8_transpose_noalloc(data, out, fp8_dtype) + else: + noop_flag = noop_flag.to(dtype=torch.float32, device=data.device) + tex.fp8_transpose_noalloc_noop(data, out, noop_flag, fp8_dtype) + + # Fill cache if needed + if fill_cache: + self._transpose = out + self._transpose_invalid = False + + return out + + @torch.no_grad() + def cast_transpose_( + self, + tensor: torch.Tensor, + noop_flag: Optional[torch.Tensor] = None, + ) -> None: + """Cast from tensor and populate transpose cache + + Tensor is reshaped as a 2D matrix. + + Parameters + ---------- + tensor: torch.Tensor + Tensor to copy from. Must have same dimensions as + destination tensor. + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid updating + destination tensor. + + """ + if self._transpose is None: + self._transpose = torch.empty( + (self.size(-1), self.numel() // self.size(-1)), + dtype=torch.uint8, + device=self.device, + ) + self.quantize_(tensor, noop_flag=noop_flag) + + @torch.no_grad() + def reset_fp8_meta_scale_inv(self) -> None: + """Replace FP8 meta tensor scale-inverse with cached value + + The FP8 meta tensor scale_inv entry corresponding to this + tensor is replaced with the scale_inv value used to construct + the tensor. + + """ + assert self._fp8_meta is not None, "FP8 meta tensors not found." + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=self._fp8_meta_forward, + ) + self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index].copy_(self._scale_inv[0]) + + def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: + """Create `Float8Tensor` with given nominal dtype + + The new tensor has the same underlying FP8 data. + + """ + return Float8Tensor.make_like( + self, + data=self._data, + fp8_attrs=self._fp8_attrs, + dtype=dtype, + ) + + def _reset_caches(self) -> None: + """ + Set transpose cache as invalid. + Should be called after any in-place operation. 
+ """ + self._transpose_invalid = True + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # Slice op + if func == aten.slice.Tensor: + tensor = args[0] + data = tensor._data + data_slice = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return Float8Tensor.make_like(tensor, data=data_slice) + + # View op + if func == aten.view.default: + tensor = args[0] + data = tensor._data + data_view = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return Float8Tensor.make_like(tensor, data=data_view) + + # Default case + return super().__torch_dispatch__(func, types, args, kwargs) + + @classmethod + def _make_in_reduce_ex( + cls, + data: torch.Tensor, + fp8_dtype: TE_DType, + fp8_scale_inv: torch.Tensor, + dtype: torch.dtype, + ) -> Float8Tensor: + """Build Float8Tensor, for use in __reduce__ + + __reduce_ex__ assumes object constructor has positional + arguments. + + """ + return Float8Tensor( + data=data, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + dtype=dtype, + ) + + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling to remove references to FP8 metadata objects""" + return ( + Float8Tensor._make_in_reduce_ex, + (self._data, self._fp8_dtype, self._scale_inv, self.dtype), + ) + + def _get_data(self) -> Float8Tensor: + """Get tensor data property""" + return super().data + + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + + Cast tensor to FP8 and store in FP8 buffer. + + """ + with torch.no_grad(): + self.copy_(tensor) + + # Cast to FP8 when setting Float8Tensor.data + data = property(_get_data, _set_data) + + # Accessors for objects in self._fp8_attrs + # Note: We store FP8 attributes in a dictionary so we can share + # them between tensors with the same data, e.g. detached tensors. + # For convenience, we also expose them as property attributes. + _fp8_meta = property(**_make_fp8_attr_property_funcs("fp8_meta")) + _fp8_meta_forward = property(**_make_fp8_attr_property_funcs("fp8_meta_forward")) + _fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index")) + _fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype")) + _transpose = property(**_make_fp8_attr_property_funcs("transpose")) + _transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid")) + _scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv")) diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py new file mode 100644 index 0000000000..f890b0878a --- /dev/null +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -0,0 +1,172 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Tensor with quantized data""" + +from __future__ import annotations +from typing import Optional, Tuple + +import torch +from torch.utils._pytree import tree_map + + +class _DequantizeFunc(torch.autograd.Function): + """Autograd function to convert quantized tensor to standard tensor""" + + @staticmethod + def forward( + _ctx: torch.autograd.function.FunctionCtx, # unused + tensor: QuantizedTensor, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + return tensor.dequantize(dtype=dtype) + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + return grad, None + + +class _IdentityFunc(torch.autograd.Function): + """Autograd function to create quantized tensor with same data""" + + @staticmethod + def forward( + _ctx: torch.autograd.function.FunctionCtx, # unused + tensor: QuantizedTensor, + ) -> QuantizedTensor: + return tensor.detach() + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> torch.Tensor: + return grad + + +class QuantizedTensor(torch.Tensor): + """Abstract base class for tensor with quantized data + + This is a proxy class with the interface of a standard PyTorch + tensor, but with data that has been encoded with some quantization + scheme. Derived classes should implement the quantization scheme + by overriding the `quantize_` and `dequantize` functions. + + """ + + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """Convert quantized data to standard PyTorch tensor""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement dequantize function" + ) + + def quantize_(self, tensor: torch.Tensor) -> QuantizedTensor: + """Update quantized data in-place""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement quantize_ function" + ) + + def detach(self) -> QuantizedTensor: + """Create new quantized tensor with same data + + Output tensor must be detached from the current autograd + graph. + + """ + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement detach function" + ) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})" + + def float(self) -> torch.Tensor: + return _DequantizeFunc.apply(self, torch.float32) + + def bfloat16(self) -> torch.Tensor: + return _DequantizeFunc.apply(self, torch.bfloat16) + + def half(self) -> torch.Tensor: + return _DequantizeFunc.apply(self, torch.float16) + + def cpu(self) -> torch.Tensor: + return _DequantizeFunc.apply(self).cpu() + + def expand_as(self, other: torch.Tensor) -> torch.Tensor: + if other is self: + # Note: expand_as is hackily used to create dummy autograd nodes + # and access the backward graph (see + # https://github.com/pytorch/pytorch/blob/238fb660851268f44ff88127887041fea352fe48/torch/nn/parallel/distributed.py#L1026). + # We hackily add a dummy function to handle this case. 
+ return _IdentityFunc.apply(self) + return super().expand_as(other) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # Detach op + if func == torch.ops.aten.detach.default: + return args[0].detach() + + # In-place copy op + if func == torch.ops.aten.copy_.default: + dst = args[0] + src = args[1] + if isinstance(dst, QuantizedTensor): + dst.quantize_(src) + else: + if isinstance(src, QuantizedTensor): + src = src.dequantize() + dst.copy_(src) + return None + + # View op + if func == torch.ops.aten.view.default: + raise NotImplementedError(f"{cls.__name__} class does not support tensor views") + + def maybe_unwrap(arg): + if isinstance(arg, QuantizedTensor): + return arg.dequantize(dtype=arg.dtype) + return arg + + def maybe_update_inplace(arg, new_arg, schema_arg): + if ( + isinstance(arg, QuantizedTensor) + and isinstance(new_arg, torch.Tensor) + and hasattr(schema_arg, "alias_info") + and hasattr(schema_arg.alias_info, "is_write") + and schema_arg.alias_info.is_write + ): + arg.quantize_(new_arg) + + # In-place op: dequantize, perform op, and quantize + if func._schema.is_mutable: + new_args = tree_map(maybe_unwrap, args) + new_kwargs = tree_map(maybe_unwrap, kwargs) + schema_args = func._schema.arguments + args_len = len(args) + super().__torch_dispatch__(func, types, new_args, new_kwargs) + for arg, new_arg, schema_arg in zip(args, new_args, schema_args): + maybe_update_inplace(arg, new_arg, schema_arg) + for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]): + assert kwarg == new_kwarg == schema_arg.name, "name of the kw argument should match" + maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg) + return None + + # Default op: dequantize and perform op + args = tree_map(maybe_unwrap, args) + if kwargs is not None: + kwargs = tree_map(maybe_unwrap, kwargs) + out = super().__torch_dispatch__(func, types, args, kwargs) + return out + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + # Do not force the QuantizedTensor type on the returned tensor + return torch._C._disabled_torch_function_impl(func, types, args, kwargs) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index bd6e27594d..020d262be2 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -173,7 +173,7 @@ class TransformerLayer(torch.nn.Module): Type of activation used in MLP block. Options are: 'gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu' and 'srelu'. device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. attn_input_format: {'sbhd', 'bshd'}, default = 'sbhd' @@ -503,7 +503,13 @@ def set_context_parallel_group( cuda stream for context parallel execution. cp_comm_type : str inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather". + Can be "p2p" or "all_gather" or "a2a". + "p2p": Exchange KV chunks with P2P communications in ring topology. + P2P is async and can be overlapped with attention compute. + "all_gather": All-gather to get full sequence of KV before attention. + The all-gather is not async, and cannot be overlapped. 
+ "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP + group, and gather to get full sequence of QKV. """ # Deep iterate but skip self to avoid infinite recursion. for index, child in enumerate(self.modules()): diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 5e3fa05f52..d5145455b8 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Utility functions for Transformer Engine modules""" +from __future__ import annotations import functools import math from typing import Any, Callable, Optional, Tuple @@ -251,3 +252,52 @@ def get_cudnn_version() -> Tuple[int, int, int]: major, encoded_version = divmod(encoded_version, major_version_magnitude) minor, patch = divmod(encoded_version, 100) return (major, minor, patch) + + +def canonicalize_device(device: Optional[torch.device | str]) -> torch.device: + """Canonicalize PyTorch device + + If `None`, then returns the default CUDA device. + + """ + if device is None: + # Use default CUDA device + device = torch.get_default_device() + if device.type != "cuda": + device = torch.device("cuda", torch.cuda.current_device()) + elif not isinstance(device, torch.device): + device = torch.device(device) + if device.type == "cuda" and device.index is None: + device = torch.device("cuda", torch.cuda.current_device()) + return device + + +def canonicalize_dtype(dtype: Optional[torch.dtype]) -> torch.dtype: + """Canonicalize PyTorch datatype + + If `None`, then returns the default PyTorch datatype. + + """ + if dtype is None: + # Use default dtype + dtype = torch.get_default_dtype() + return dtype + + +def devices_match(device1: torch.device, device2: torch.device) -> bool: + """Whether two devices are the same""" + device1 = torch.device(device1) + device2 = torch.device(device2) + if device1.type != device2.type: + return False + if device1.type == "cuda": + index1 = device1.index + index2 = device2.index + if index1 == index2: + return True + if index1 is None: + index1 = torch.cuda.current_device() + if index2 is None: + index2 = torch.cuda.current_device() + return index1 == index2 + return device1 == device2