diff --git a/.github/workflows/deploy_nightly_docs.yml b/.github/workflows/deploy_nightly_docs.yml
index 3820dbde50..cd68019c8f 100644
--- a/.github/workflows/deploy_nightly_docs.yml
+++ b/.github/workflows/deploy_nightly_docs.yml
@@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Download artifact
- uses: actions/download-artifact@v3
+ uses: actions/download-artifact@v4.1.7
with:
name: "te_docs"
path: "html"
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index a4fdbdfdfd..d2bd865a8f 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -31,7 +31,7 @@ jobs:
run: |
sudo apt-get update
sudo apt-get install pip -y
- pip install torch
+ pip install torch numpy
export PYTHON_ONLY=1
export TE_PATH=.
bash ./qa/L0_pytorch_lint/test.sh
diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml
index cd47fa9a54..c25aa863ad 100644
--- a/.github/workflows/trigger-ci.yml
+++ b/.github/workflows/trigger-ci.yml
@@ -33,6 +33,9 @@ jobs:
|| github.actor == 'Oleg-Goncharov'
|| github.actor == 'phu0ngng'
|| github.actor == 'xrennvidia'
+ || github.actor == 'yaox12'
+ || github.actor == 'huanghua1994'
+ || github.actor == 'mgoldfarb-nvidia'
)
steps:
- name: Check if comment is issued by authorized person
diff --git a/README.rst b/README.rst
index 25ed8af1de..6cc7eeae8a 100644
--- a/README.rst
+++ b/README.rst
@@ -174,7 +174,15 @@ To install the latest stable version of Transformer Engine,
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
-This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch).
+This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch,paddle).
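+
+For example, a source build limited to the PyTorch and JAX extensions might look like the following (the framework list shown here is illustrative):
+
+.. code-block:: bash
+
+   NVTE_FRAMEWORK=pytorch,jax pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable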
+
+Alternatively, the package can be installed directly from `Transformer Engine's PyPI <https://pypi.org/project/transformer-engine/>`_, e.g.
+
+.. code-block:: bash
+
+ pip install transformer_engine[pytorch]
+
+To obtain the necessary Python bindings for Transformer Engine, the required frameworks must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch,paddle]). Transformer Engine ships wheels for the core library as well as the PaddlePaddle extensions. Source distributions are shipped for the JAX and PyTorch extensions.
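+
+For example, the PyTorch and JAX bindings can be requested together with a single comma-separated extras list (the choice of frameworks here is illustrative):
+
+.. code-block:: bash
+
+   pip install transformer_engine[pytorch,jax]
+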
From source
^^^^^^^^^^^
diff --git a/benchmarks/attention/benchmark_attention.py b/benchmarks/attention/benchmark_attention.py
index bfd7bf8471..cff7c65fbc 100644
--- a/benchmarks/attention/benchmark_attention.py
+++ b/benchmarks/attention/benchmark_attention.py
@@ -156,7 +156,7 @@ def parse_results(per_cudnn, per_flash, model):
df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6
if per_flash > 0:
- t_flash_all = df[df["Name"].str.contains("void flash")]["Duration (ns)"].to_numpy()
+ t_flash_all = df[df["Name"].str.contains("flash")]["Duration (ns)"].to_numpy()
t_flash_all = t_flash_all.reshape(-1, per_flash)
t_flash_avg = np.average(t_flash_all, axis=0)
df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6
diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt
index 81c871de46..1cac385c6c 100644
--- a/build_tools/VERSION.txt
+++ b/build_tools/VERSION.txt
@@ -1 +1 @@
-1.10.0
+1.11.0
diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py
index f71cef08ea..af11ada34c 100644
--- a/build_tools/build_ext.py
+++ b/build_tools/build_ext.py
@@ -106,8 +106,12 @@ def run(self) -> None:
if isinstance(ext, CMakeExtension):
print(f"Building CMake extension {ext.name}")
# Set up incremental builds for CMake extensions
- setup_dir = Path(__file__).resolve().parent.parent
- build_dir = setup_dir / "build" / "cmake"
+ build_dir = os.getenv("NVTE_CMAKE_BUILD_DIR")
+ if build_dir:
+ build_dir = Path(build_dir).resolve()
+ else:
+ root_dir = Path(__file__).resolve().parent.parent
+ build_dir = root_dir / "build" / "cmake"
# Ensure the directory exists
build_dir.mkdir(parents=True, exist_ok=True)
diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py
index f932f0695e..4563a0272a 100644
--- a/build_tools/pytorch.py
+++ b/build_tools/pytorch.py
@@ -10,8 +10,9 @@
from .utils import (
all_files_in_dir,
- cuda_version,
+ cuda_archs,
cuda_path,
+ cuda_version,
)
@@ -48,8 +49,6 @@ def setup_pytorch_extension(
]
nvcc_flags = [
"-O3",
- "-gencode",
- "arch=compute_70,code=sm_70",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
@@ -61,6 +60,11 @@ def setup_pytorch_extension(
"--use_fast_math",
]
+ cuda_architectures = cuda_archs()
+
+ if "70" in cuda_architectures:
+ nvcc_flags.extend(["-gencode", "arch=compute_70,code=sm_70"])
+
# Version-dependent CUDA options
try:
version = cuda_version()
@@ -73,17 +77,18 @@ def setup_pytorch_extension(
(
"--threads",
os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"),
- "-gencode",
- "arch=compute_80,code=sm_80",
- "-gencode",
- "arch=compute_90,code=sm_90",
)
)
+ if "80" in cuda_architectures:
+ nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
+ if "90" in cuda_architectures:
+ nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"])
+
# Libraries
library_dirs = []
libraries = []
- if os.getenv("NVTE_UB_WITH_MPI"):
+ if bool(int(os.getenv("NVTE_UB_WITH_MPI", 0))):
assert (
os.getenv("MPI_HOME") is not None
), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
diff --git a/build_tools/utils.py b/build_tools/utils.py
index 27ceea844b..d846b87f22 100644
--- a/build_tools/utils.py
+++ b/build_tools/utils.py
@@ -6,12 +6,12 @@
import functools
import glob
+import importlib
import os
import re
import shutil
import subprocess
import sys
-import importlib
from pathlib import Path
from subprocess import CalledProcessError
from typing import List, Optional, Tuple, Union
@@ -188,6 +188,11 @@ def cuda_path() -> Tuple[str, str]:
return cuda_home, nvcc_bin
+@functools.lru_cache(maxsize=None)
+def cuda_archs() -> str:
+ return os.getenv("NVTE_CUDA_ARCHS", "70;80;89;90")
+
+
def cuda_version() -> Tuple[int, ...]:
"""CUDA Toolkit version as a (major, minor) tuple."""
# Query NVCC for version info
diff --git a/docs/_templates/layout.html b/docs/_templates/layout.html
index a68b4531e3..f94e526f57 100644
--- a/docs/_templates/layout.html
+++ b/docs/_templates/layout.html
@@ -1,4 +1,11 @@
{% extends "!layout.html" %}
+
+ {% block extrahead %}
+
+
+
+ {% endblock %}
+
{% block sidebartitle %} {{ super() }}
- {%- if nvidia_analytics_id %}
-
- {%- endif %}
+ {% endblock %}
+
+ {% block footer %}
+
+
{% endblock %}
diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst
index c9504c20af..b097f14475 100644
--- a/docs/api/pytorch.rst
+++ b/docs/api/pytorch.rst
@@ -9,6 +9,9 @@ pyTorch
.. autoapiclass:: transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs)
:members: forward, set_tensor_parallel_group
+.. autoapiclass:: transformer_engine.pytorch.GroupedLinear(in_features, out_features, bias=True, **kwargs)
+ :members: forward, set_tensor_parallel_group
+
.. autoapiclass:: transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs)
.. autoapiclass:: transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs)
@@ -44,3 +47,7 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables
.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
+
+.. autoapifunction:: transformer_engine.pytorch.moe_permute
+
+.. autoapifunction:: transformer_engine.pytorch.moe_unpermute
diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py
index cb384aa10c..4413bdfd00 100644
--- a/docs/examples/te_llama/te_llama.py
+++ b/docs/examples/te_llama/te_llama.py
@@ -102,8 +102,11 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k
Custom method adapted from `from_pretrained` method in HuggingFace
Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
"""
- vanilla_model = cls(config).to(kwargs["torch_dtype"])
- is_local = os.path.isdir(pretrained_model_name_or_path)
+ # Before loading the model, set the default dtype for torch
+ torch.set_default_dtype(kwargs["torch_dtype"])
+
+ # Load the vanilla model weights
+ vanilla_model = cls(config)
subfolder = ""
variant = None
if os.path.isfile(
@@ -133,7 +136,7 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k
else:
raise AssertionError("Only sharded PyTorch ckpt format supported at the moment")
- resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
+ resolved_archive_file, _ = get_checkpoint_shard_files(
pretrained_model_name_or_path,
archive_file,
)
diff --git a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
index 57c1bf6601..7013e85ec6 100644
--- a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
+++ b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
@@ -247,15 +247,24 @@
"restart_jupyter_notebook()\n",
"\n",
"\n",
- "# Import necessary packages and methods\n",
+ "# Import necessary packages, methods and variables\n",
"from utils import *\n",
"\n",
"\n",
- "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
- "## !!! `model_name` attr must point to the location of the model weights !!!\n",
- "# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n",
- "# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n",
- "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
+ "# Provide Huggingface Access Token\n",
+ "hyperparams.hf_access_token = \"\"\n",
+ "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
+ "\n",
+ "# Provide a directory to cache weights in to avoid downloading them every time.\n",
+ "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
+ "hyperparams.weights_cache_dir = \"\"\n",
+ "\n",
+ "# For Llama 2, uncomment this line (also set by default)\n",
+ "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
+ "\n",
+ "# For Llama 3, uncomment this line\n",
+ "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
+ "\n",
"hyperparams.mixed_precision = \"bf16\"\n",
"\n",
"\n",
@@ -554,7 +563,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 1,
"id": "bdb34b91",
"metadata": {},
"outputs": [
@@ -573,15 +582,24 @@
"restart_jupyter_notebook()\n",
"\n",
"\n",
- "# Import necessary packages and methods\n",
+ "# Import necessary packages, methods and variables\n",
"from utils import *\n",
"\n",
"\n",
- "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
- "## !!! `model_name` attr must point to the location of the model weights !!!\n",
- "# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n",
- "# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n",
- "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
+ "# Provide Huggingface Access Token\n",
+ "hyperparams.hf_access_token = \"\"\n",
+ "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
+ "\n",
+ "# Provide a directory to cache weights in to avoid downloading them every time.\n",
+ "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
+ "hyperparams.weights_cache_dir = \"\"\n",
+ "\n",
+ "# For Llama 2, uncomment this line (also set by default)\n",
+ "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
+ "\n",
+ "# For Llama 3, uncomment this line\n",
+ "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
+ "\n",
"hyperparams.mixed_precision = \"bf16\"\n",
"\n",
"\n",
@@ -653,15 +671,24 @@
"restart_jupyter_notebook()\n",
"\n",
"\n",
- "# Import necessary packages and methods\n",
+ "# Import necessary packages, methods and variables\n",
"from utils import *\n",
"\n",
"\n",
- "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
- "## !!! `model_name` attr must point to the location of the model weights !!!\n",
- "# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n",
- "# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n",
- "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
+ "# Provide Huggingface Access Token\n",
+ "hyperparams.hf_access_token = \"\"\n",
+ "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
+ "\n",
+ "# Provide a directory to cache weights in to avoid downloading them every time.\n",
+ "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
+ "hyperparams.weights_cache_dir = \"\"\n",
+ "\n",
+ "# For Llama 2, uncomment this line (also set by default)\n",
+ "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
+ "\n",
+ "# For Llama 3, uncomment this line\n",
+ "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
+ "\n",
"hyperparams.mixed_precision = \"fp8\"\n",
"\n",
"\n",
diff --git a/docs/examples/te_llama/utils.py b/docs/examples/te_llama/utils.py
index b6b3683d4c..1aebe13afb 100644
--- a/docs/examples/te_llama/utils.py
+++ b/docs/examples/te_llama/utils.py
@@ -25,7 +25,10 @@
class HyperParameters:
def __init__(self):
self.mixed_precision = "bf16"
- # self.model_name = "" # <== Add model weight location here
+
+ # Set to Meta Llama 2 by default.
+ self.model_name = "meta-llama/Llama-2-7b-hf"
+
self.dataset_name = "timdettmers/openassistant-guanaco"
self.dataset_text_field = "text"
self.learning_rate = 1.41e-5
@@ -35,6 +38,10 @@ def __init__(self):
self.num_warmup_steps = 5
self.num_training_steps = 10
+ # This is either provided by the user or it will be set when the
+ # model weights are downloaded.
+ self.weights_cache_dir = ""
+
hyperparams = HyperParameters()
@@ -76,13 +83,49 @@ def tokenize(element):
return train_dataloader
+def ensure_model_is_downloaded(hyperparams):
+ assert hyperparams.model_name in [
+ "meta-llama/Meta-Llama-3-8B",
+ "meta-llama/Llama-2-7b-hf",
+ ], "Only Meta Llama 2 7B and Meta Llama 3 8B models are supported!"
+
+ # Login using Huggingface Hub API
+ from huggingface_hub import login
+
+ try:
+ login(hyperparams.hf_access_token)
+ except Exception as e:
+ if "Invalid token passed!" in str(e):
+ print(
+ "Please pass a valid HF Access Token! More info at"
+ " https://huggingface.co/docs/hub/en/security-tokens."
+ )
+ else:
+ print(f"Exception is {e}")
+
+ # Download the model if it doesn't exist
+ from huggingface_hub import snapshot_download
+
+ supplied_cache_dir = (
+ hyperparams.weights_cache_dir if hyperparams.weights_cache_dir != "" else None
+ )
+ hyperparams.weights_cache_dir = snapshot_download(
+ repo_id=hyperparams.model_name, cache_dir=supplied_cache_dir
+ )
+
+ print(f"Model cache directory : {hyperparams.weights_cache_dir}")
+
+
def init_baseline_model(hyperparams):
+ # Download and cache the weights
+ ensure_model_is_downloaded(hyperparams)
+
# Init the model
- config = AutoConfig.from_pretrained(hyperparams.model_name)
+ config = AutoConfig.from_pretrained(hyperparams.weights_cache_dir)
# make sure to use flash_attention to do iso comparison with TELlamaModel
config._attn_implementation = "flash_attention_2"
model = AutoModelForCausalLM.from_pretrained(
- hyperparams.model_name,
+ hyperparams.weights_cache_dir,
config=config,
torch_dtype=torch.bfloat16,
)
@@ -94,13 +137,16 @@ def init_baseline_model(hyperparams):
def init_te_llama_model(hyperparams):
+ # Download and cache the weights
+ ensure_model_is_downloaded(hyperparams)
+
# Init the model
from te_llama import TELlamaForCausalLM
- config = AutoConfig.from_pretrained(hyperparams.model_name)
+ config = AutoConfig.from_pretrained(hyperparams.weights_cache_dir)
config._attn_implementation = "flash_attention_2"
model = TELlamaForCausalLM.from_pretrained_local(
- hyperparams.model_name,
+ hyperparams.weights_cache_dir,
config=config,
torch_dtype=torch.bfloat16,
)
diff --git a/docs/installation.rst b/docs/installation.rst
index 012f3303cb..9ac0ddf841 100644
--- a/docs/installation.rst
+++ b/docs/installation.rst
@@ -28,6 +28,17 @@ Transformer Engine library is preinstalled in the PyTorch container in versions
on `NVIDIA GPU Cloud `_.
+pip - from PyPI
+-----------------------
+
+Transformer Engine can be installed directly from `our PyPI <https://pypi.org/project/transformer-engine/>`_, e.g.
+
+.. code-block:: bash
+
+ pip install transformer_engine[pytorch]
+
+To obtain the necessary Python bindings for Transformer Engine, the required frameworks must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch,paddle]). Transformer Engine ships wheels for the core library as well as the PaddlePaddle extensions. Source distributions are shipped for the JAX and PyTorch extensions.
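+
+For example, to request the bindings for all three frameworks at once (illustrative):
+
+.. code-block:: bash
+
+   pip install transformer_engine[jax,pytorch,paddle]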
+
pip - from GitHub
-----------------------
diff --git a/examples/jax/encoder/README.md b/examples/jax/encoder/README.md
index f8e150eac6..575f7be6e0 100644
--- a/examples/jax/encoder/README.md
+++ b/examples/jax/encoder/README.md
@@ -1,6 +1,6 @@
# Basic Transformer Encoder Example with Optional FP8 #
-This example uses Transformer Encoder to demonstrate the Transformer Engine usage. And more focus on scaling up training on multiple GPUs. Highly recommend studying the [MNIST example of the Transformer Engine](/examples/jax/mnist) before reading this example. The Transformer Engine is built on top of [Flax](https://github.com/google/flax). Thus, examples use `pjit` to set up multiple GPU training. The basic pjit usage can be referred to [Scale up Flax Modules on multiple devices with pjit](https://flax.readthedocs.io/en/latest/guides/flax_on_pjit.html).
+This example uses a Transformer Encoder to demonstrate Transformer Engine usage, with a focus on scaling up training to multiple GPUs. We highly recommend studying the [MNIST example of the Transformer Engine](/examples/jax/mnist) before reading this example. The Transformer Engine is built on top of [Flax](https://github.com/google/flax). Thus, the examples use `jit` with the `in_shardings` and `out_shardings` parameters to set up multi-GPU training. For basic parallel jit usage, refer to [Scale up Flax Modules on multiple devices](https://flax.readthedocs.io/en/latest/guides/flax_on_pjit.html).
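+
+The snippet below is a minimal, self-contained sketch of this jit-plus-sharding pattern. It uses a plain `flax.linen.Dense` layer instead of the encoder from this example, and all names, shapes, and the `'data'` axis size are illustrative:
+
+```python
+import jax
+import jax.numpy as jnp
+import flax.linen as nn
+from jax.experimental import mesh_utils
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+
+# One mesh axis named 'data' spanning all visible devices (pure data parallelism).
+mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ("data",))
+
+model = nn.Dense(features=16)
+x = jnp.ones((32, 8))  # (batch, hidden); batch must be divisible by the device count
+
+# Shard the batch dimension over 'data'; keep the parameters replicated on every device.
+x_sharding = NamedSharding(mesh, PartitionSpec("data", None))
+param_sharding = NamedSharding(mesh, PartitionSpec())
+
+with mesh:
+    jit_init = jax.jit(model.init, in_shardings=(None, x_sharding), out_shardings=param_sharding)
+    params = jit_init(jax.random.PRNGKey(0), x)
+
+    jit_apply = jax.jit(model.apply, in_shardings=(param_sharding, x_sharding), out_shardings=x_sharding)
+    y = jit_apply(params, x)  # y is sharded over the batch dimension
+```
+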
## Single GPU ##
@@ -31,11 +31,11 @@ python test_single_gpu_encoder.py --use-fp8
4. The next is to create sharding rules, mapping the device axis to the logical axis. The `te.extend_logical_axis_rules` under fp8_autocast will return a list of pairs of the mapping, such as `(('batch', 'data'), ...)`. The first is the logical axis and second is the device axis.
-5. Refer structure of `abs_var_collect['params']` and `abs_var_collect['params_axes']` to set up `PartitionSpec` for pjit. All logical axes should be replaced by device axes. If the value of PartitionSpec is None, that means no sharding, broadcasting the data to every device. Note that the `params_axes` attribute is provided by Transformer Engine. The Flax's module doesn't have it, such as `nn.Embed`. For nn.Embed, assigning an empty PartitionSpec is fine because each device has its own embedding layer in DP mode. The `get_params_pspec` routine is used for this purpose. Because each device has a complete model in DP mode, all values of PartitionSpec in params_pspec should be None. This will be different in the model parallelism example.
+5. Refer to the structure of `abs_var_collect['params']` and `abs_var_collect['params_axes']` to set up the `PartitionSpec` for parallel jit. All logical axes should be replaced by device axes. If the value of a `PartitionSpec` is None, that means no sharding; the data is broadcast to every device. Note that the `params_axes` attribute is provided by Transformer Engine; Flax's own modules, such as `nn.Embed`, don't have it. For `nn.Embed`, assigning an empty `PartitionSpec` is fine because each device has its own embedding layer in DP mode. The `get_params_sharding` routine is used for this purpose. Because each device has a complete model in DP mode, all values of the `PartitionSpec` in `params_sharding` should be None. This will be different in the model parallelism example.
-6. Fill in `params_pspec` and `encoder.init` to pjit to get a compiled function, `pjit_encoder_init `, and use it to initialize the model, so JAX now can know how to do the sharding.
+6. Pass `params_sharding` and `encoder.init` to `jax.jit` to get a compiled function, `jit_encoder_init`, and use it to initialize the model, so that JAX knows how to do the sharding.
-7. The `train_step` and `eval_step` also need to be compiled by pjit. Thus, every input and output argument has to be set up `PartitionSpec` if the argument contains a tensor. For instance, the `input_pspec` is `PartitionSpec('data', None)` because the input shape is (batch size, sequence length). Then, the rest of the workflow is similar to the previous example.
+7. The `train_step` and `eval_step` also need to be compiled by jit. Thus, every input and output argument that contains a tensor has to have its sharding set up. For instance, the inputs are sharded with `PartitionSpec('data', None)` because the input shape is (batch size, sequence length). Then, the rest of the workflow is similar to the previous example.
8. Use `CUDA_VISIBLE_DEVICES` to control the number of GPUs used. For example, if the system has 8 GPUs but only 4 GPUs need to be used, then:
```sh
@@ -84,7 +84,7 @@ python test_model_parallel_encoder.py --use-fp8
1. This example inherits previous model parallelism example, but uses multiprocessing instead of single-program multiple-data (SPMD). It uses 1 GPU per process.
-2. The benefit of multiprocessing is to setup hardware affinity for GPUs, such as NUMA binding. It may help improve performance and stability. Please refer to [Best Practices When Benchmarking CUDA Applications](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2019-s9956/) for more details.
+2. There are two main benefits of multiprocessing: multi-node support and the ability to set up hardware affinity for GPUs, such as NUMA binding. Affinity may help improve performance and stability. Please refer to [Best Practices When Benchmarking CUDA Applications](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2019-s9956/) for more details.
3. The quick way to check system topology is to use `nvidia-smi`, for example:
```sh
diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py
index 646d6e0a12..25d744887e 100644
--- a/examples/jax/encoder/test_model_parallel_encoder.py
+++ b/examples/jax/encoder/test_model_parallel_encoder.py
@@ -17,7 +17,7 @@
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
from jax.experimental import mesh_utils
-from jax.experimental.pjit import pjit
+from jax.sharding import PartitionSpec, NamedSharding
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
@@ -223,32 +223,36 @@ def check_fp8(state, var_collect, inputs, masks, labels):
)
-def get_params_pspec(sharding_rules, abs_var_collect):
- """Refer params to create params partition spec"""
- rules_dict = {}
- for key, value in sharding_rules:
- rules_dict[key] = value
+def get_params_sharding(sharding_rules, abs_var_collect, mesh):
+ """Refer params to create params sharding"""
+ rules_dict = dict(sharding_rules)
def to_device_axis(logical_axis):
partitions = [rules_dict[key] for key in logical_axis]
- return jax.sharding.PartitionSpec(*partitions)
+ return NamedSharding(mesh, PartitionSpec(*partitions))
params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
- params_axes_pspec = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes))
- params_axes_pspec = flax.core.unfreeze(params_axes_pspec)
- params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY])
- params_pspec = {**params_pspec, **params_axes_pspec}
- return params_pspec
+ params_axes_sharding = jax.tree_util.tree_map(
+ to_device_axis, nn_partitioning.get_axis_names(params_axes)
+ )
+ params_axes_sharding = flax.core.unfreeze(params_axes_sharding)
+ params_sharding = jax.tree_util.tree_map(
+ lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY]
+ )
+ params_sharding = {**params_sharding, **params_axes_sharding}
+ return params_sharding
-def get_state_pspec(state, params_pspec):
- """Refer params_pspec to create state partition spec"""
+def get_state_sharding(state, params_sharding):
+ """Refer params_sharding to create state sharding"""
def replace_params(x):
- return params_pspec if isinstance(x, dict) else None
+ return params_sharding if isinstance(x, dict) else None
- state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, dict))
- return state_pspec
+ state_sharding = jax.tree_util.tree_map(
+ replace_params, state, is_leaf=lambda x: isinstance(x, dict)
+ )
+ return state_sharding
def train_and_evaluate(args):
@@ -270,7 +274,9 @@ def train_and_evaluate(args):
), f"Test batch size needs to be multiple of {num_gpu_dp}"
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
- with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)):
+ with jax.sharding.Mesh(
+ devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
+ ) as mesh:
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
@@ -291,34 +297,39 @@ def train_and_evaluate(args):
customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules
- params_pspec = get_params_pspec(sharding_rules, abs_var_collect)
- inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
- masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
+ params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
+ inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
+ masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None))
- in_shardings = (None, inputs_pspec, masks_pspec)
+ in_shardings = (None, inputs_sharding, masks_sharding)
out_shardings = {
- key: params_pspec if key is PARAMS_KEY else None for key in abs_var_collect
+ key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect
}
- pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings)
- var_collect = pjit_encoder_init(init_rngs, inputs, masks)
+ jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
+ var_collect = jit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr)
var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
state = train_state.TrainState.create(
apply_fn=encoder.apply, params=params, tx=optimizer
)
- state_pspec = get_state_pspec(state, params_pspec)
- labels_pspec = jax.sharding.PartitionSpec(
- DEVICE_DP_AXIS,
+ state_sharding = get_state_sharding(state, params_sharding)
+ labels_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS))
+
+ in_shardings = (
+ state_sharding,
+ inputs_sharding,
+ masks_sharding,
+ labels_sharding,
+ None,
+ None,
)
+ out_shardings = (state_sharding, None, None, None)
+ jit_train_step = jax.jit(train_step, in_shardings, out_shardings)
- in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None)
- out_shardings = (state_pspec, None, None, None)
- pjit_train_step = pjit(train_step, in_shardings, out_shardings)
-
- in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None)
+ in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
out_shardings = (None, None)
- pjit_eval_step = pjit(eval_step, in_shardings, out_shardings)
+ jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)
if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
@@ -327,7 +338,7 @@ def train_and_evaluate(args):
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
- pjit_train_step(state, inputs, masks, labels, var_collect, rngs)
+ jit_train_step(state, inputs, masks, labels, var_collect, rngs)
print("PASSED")
return None
@@ -337,11 +348,11 @@ def train_and_evaluate(args):
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
- state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step
+ state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
)
test_loss, test_accuracy = eval_model(
- state, test_ds, args.test_batch_size, var_collect, pjit_eval_step
+ state, test_ds, args.test_batch_size, var_collect, jit_eval_step
)
print(
diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py
index 005ae50e72..9d08254f4d 100644
--- a/examples/jax/encoder/test_multigpu_encoder.py
+++ b/examples/jax/encoder/test_multigpu_encoder.py
@@ -17,7 +17,7 @@
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
from jax.experimental import mesh_utils
-from jax.experimental.pjit import pjit
+from jax.sharding import PartitionSpec, NamedSharding
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
@@ -202,32 +202,36 @@ def check_fp8(state, var_collect, inputs, masks, labels):
)
-def get_params_pspec(sharding_rules, abs_var_collect):
- """Refer params to create params partition spec"""
- rules_dict = {}
- for key, value in sharding_rules:
- rules_dict[key] = value
+def get_params_sharding(sharding_rules, abs_var_collect, mesh):
+ """Refer params to create params sharding"""
+ rules_dict = dict(sharding_rules)
def to_device_axis(logical_axis):
partitions = [rules_dict[key] for key in logical_axis]
- return jax.sharding.PartitionSpec(*partitions)
+ return NamedSharding(mesh, PartitionSpec(*partitions))
params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
- params_axes_pspec = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes))
- params_axes_pspec = flax.core.unfreeze(params_axes_pspec)
- params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY])
- params_pspec = {**params_pspec, **params_axes_pspec}
- return params_pspec
+ params_axes_sharding = jax.tree_util.tree_map(
+ to_device_axis, nn_partitioning.get_axis_names(params_axes)
+ )
+ params_axes_sharding = flax.core.unfreeze(params_axes_sharding)
+ params_sharding = jax.tree_util.tree_map(
+ lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY]
+ )
+ params_sharding = {**params_sharding, **params_axes_sharding}
+ return params_sharding
-def get_state_pspec(state, params_pspec):
- """Refer params_pspec to create state partition spec"""
+def get_state_sharding(state, params_sharding):
+ """Refer params_sharding to create state sharding"""
def replace_params(x):
- return params_pspec if isinstance(x, dict) else None
+ return params_sharding if isinstance(x, dict) else None
- state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, dict))
- return state_pspec
+ state_sharding = jax.tree_util.tree_map(
+ replace_params, state, is_leaf=lambda x: isinstance(x, dict)
+ )
+ return state_sharding
def train_and_evaluate(args):
@@ -240,7 +244,7 @@ def train_and_evaluate(args):
assert args.test_batch_size % num_gpu == 0, f"Test batch size needs to be multiple of {num_gpu}"
device_mesh = mesh_utils.create_device_mesh((num_gpu,))
- with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)):
+ with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)) as mesh:
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
@@ -260,34 +264,43 @@ def train_and_evaluate(args):
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
sharding_rules = te_flax.extend_logical_axis_rules(tuple())
- params_pspec = get_params_pspec(sharding_rules, abs_var_collect)
- inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
- masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
+ params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
+ inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
+ masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None))
- in_shardings = (None, inputs_pspec, masks_pspec)
+ in_shardings = (None, inputs_sharding, masks_sharding)
out_shardings = {
- key: params_pspec if key is PARAMS_KEY else None for key in abs_var_collect
+ key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect
}
- pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings)
- var_collect = pjit_encoder_init(init_rngs, inputs, masks)
+ jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
+ var_collect = jit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr)
var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
state = train_state.TrainState.create(
apply_fn=encoder.apply, params=params, tx=optimizer
)
- state_pspec = get_state_pspec(state, params_pspec)
- labels_pspec = jax.sharding.PartitionSpec(
- DEVICE_DP_AXIS,
+ state_sharding = get_state_sharding(state, params_sharding)
+ labels_sharding = NamedSharding(
+ mesh,
+ PartitionSpec(
+ DEVICE_DP_AXIS,
+ ),
)
+ in_shardings = (
+ state_sharding,
+ inputs_sharding,
+ masks_sharding,
+ labels_sharding,
+ None,
+ None,
+ )
+ out_shardings = (state_sharding, None, None, None)
+ jit_train_step = jax.jit(train_step, in_shardings, out_shardings)
- in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None)
- out_shardings = (state_pspec, None, None, None)
- pjit_train_step = pjit(train_step, in_shardings, out_shardings)
-
- in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None)
+ in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
out_shardings = (None, None)
- pjit_eval_step = pjit(eval_step, in_shardings, out_shardings)
+ jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)
if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
@@ -296,7 +309,7 @@ def train_and_evaluate(args):
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
- pjit_train_step(state, inputs, masks, labels, var_collect, rngs)
+ jit_train_step(state, inputs, masks, labels, var_collect, rngs)
print("PASSED")
return None
@@ -306,11 +319,11 @@ def train_and_evaluate(args):
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
- state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step
+ state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
)
test_loss, test_accuracy = eval_model(
- state, test_ds, args.test_batch_size, var_collect, pjit_eval_step
+ state, test_ds, args.test_batch_size, var_collect, jit_eval_step
)
print(
diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py
index 286c064e96..e581dbc3f9 100644
--- a/examples/jax/encoder/test_multiprocessing_encoder.py
+++ b/examples/jax/encoder/test_multiprocessing_encoder.py
@@ -19,7 +19,7 @@
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
from jax.experimental import mesh_utils
-from jax.experimental.pjit import pjit
+from jax.sharding import PartitionSpec, NamedSharding
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
@@ -305,32 +305,36 @@ def check_fp8(state, var_collect, inputs, masks, labels):
)
-def get_params_pspec(sharding_rules, abs_var_collect):
- """Refer params to create params partition spec"""
- rules_dict = {}
- for key, value in sharding_rules:
- rules_dict[key] = value
+def get_params_sharding(sharding_rules, abs_var_collect, mesh):
+ """Refer params to create params sharding"""
+ rules_dict = dict(sharding_rules)
def to_device_axis(logical_axis):
partitions = [rules_dict[key] for key in logical_axis]
- return jax.sharding.PartitionSpec(*partitions)
+ return NamedSharding(mesh, jax.sharding.PartitionSpec(*partitions))
params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
- params_axes_pspec = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes))
- params_axes_pspec = flax.core.unfreeze(params_axes_pspec)
- params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY])
- params_pspec = {**params_pspec, **params_axes_pspec}
- return params_pspec
+ params_axes_sharding = jax.tree_util.tree_map(
+ to_device_axis, nn_partitioning.get_axis_names(params_axes)
+ )
+ params_axes_sharding = flax.core.unfreeze(params_axes_sharding)
+ params_sharding = jax.tree_util.tree_map(
+ lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY]
+ )
+ params_sharding = {**params_sharding, **params_axes_sharding}
+ return params_sharding
-def get_state_pspec(state, params_pspec):
- """Refer params_pspec to create state partition spec"""
+def get_state_sharding(state, params_sharding):
+ """Refer params_sharding to create state sharding"""
def replace_params(x):
- return params_pspec if isinstance(x, dict) else None
+ return params_sharding if isinstance(x, dict) else None
- state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, dict))
- return state_pspec
+ state_sharding = jax.tree_util.tree_map(
+ replace_params, state, is_leaf=lambda x: isinstance(x, dict)
+ )
+ return state_sharding
def train_and_evaluate(args):
@@ -362,7 +366,7 @@ def train_and_evaluate(args):
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
with jax.sharding.Mesh(
devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
- ) as shard_mesh:
+ ) as mesh:
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
@@ -383,34 +387,41 @@ def train_and_evaluate(args):
customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules
- params_pspec = get_params_pspec(sharding_rules, abs_var_collect)
+ params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
- in_shardings = (None, inputs_pspec, masks_pspec)
+ inputs_sharding = NamedSharding(mesh, inputs_pspec)
+ masks_sharding = NamedSharding(mesh, masks_pspec)
+ in_shardings = (None, inputs_sharding, masks_sharding)
out_shardings = {
- key: params_pspec if key is PARAMS_KEY else None for key in abs_var_collect
+ key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect
}
- pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings)
- var_collect = pjit_encoder_init(init_rngs, inputs, masks)
+ jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
+ var_collect = jit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr)
var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
state = train_state.TrainState.create(
apply_fn=encoder.apply, params=params, tx=optimizer
)
- state_pspec = get_state_pspec(state, params_pspec)
- labels_pspec = jax.sharding.PartitionSpec(
- DEVICE_DP_AXIS,
+ state_sharding = get_state_sharding(state, params_sharding)
+ labels_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS))
+
+ in_shardings = (
+ state_sharding,
+ inputs_sharding,
+ masks_sharding,
+ labels_sharding,
+ None,
+ None,
)
+ out_shardings = (state_sharding, None, None, None)
+ jit_train_step = jax.jit(train_step, in_shardings, out_shardings)
- in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None)
- out_shardings = (state_pspec, None, None, None)
- pjit_train_step = pjit(train_step, in_shardings, out_shardings)
-
- in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None)
+ in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
out_shardings = (None, None)
- pjit_eval_step = pjit(eval_step, in_shardings, out_shardings)
+ jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)
if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
@@ -419,7 +430,7 @@ def train_and_evaluate(args):
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
- pjit_train_step(state, inputs, masks, labels, var_collect, rngs)
+ jit_train_step(state, inputs, masks, labels, var_collect, rngs)
print("PASSED")
else:
for epoch in range(1, args.epochs + 1):
@@ -433,11 +444,11 @@ def train_and_evaluate(args):
args.batch_size,
rngs,
var_collect,
- pjit_train_step,
- shard_mesh,
+ jit_train_step,
+ mesh,
inputs_pspec,
masks_pspec,
- labels_pspec,
+ labels_sharding.spec,
)
test_loss, test_accuracy = eval_model(
@@ -445,11 +456,11 @@ def train_and_evaluate(args):
test_ds,
args.test_batch_size,
var_collect,
- pjit_eval_step,
- shard_mesh,
+ jit_eval_step,
+ mesh,
inputs_pspec,
masks_pspec,
- labels_pspec,
+ labels_sharding.spec,
)
if args.process_id == 0:
print(
diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh
index e6ccf3b82f..b69aed6648 100644
--- a/qa/L0_pytorch_unittest/test.sh
+++ b/qa/L0_pytorch_unittest/test.sh
@@ -23,3 +23,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops_distributed.py
+pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py
diff --git a/setup.py b/setup.py
index 942f57d3c1..0b0639aea6 100644
--- a/setup.py
+++ b/setup.py
@@ -13,18 +13,17 @@
from wheel.bdist_wheel import bdist_wheel
from build_tools.build_ext import CMakeExtension, get_build_ext
+from build_tools.te_version import te_version
from build_tools.utils import (
+ cuda_archs,
found_cmake,
found_ninja,
found_pybind11,
- remove_dups,
get_frameworks,
install_and_import,
remove_dups,
uninstall_te_wheel_packages,
)
-from build_tools.te_version import te_version
-
frameworks = get_frameworks()
current_file_path = Path(__file__).parent.resolve()
@@ -60,10 +59,11 @@ def setup_common_extension() -> CMakeExtension:
"""Setup CMake extension for common library"""
# Project directory root
root_path = Path(__file__).resolve().parent
+
return CMakeExtension(
name="transformer_engine",
cmake_path=root_path / Path("transformer_engine/common"),
- cmake_flags=[],
+ cmake_flags=["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())],
)
@@ -93,7 +93,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
# Framework-specific requirements
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
- install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.5.8,!=2.0.9,!=2.1.0"])
+ install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"])
test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"])
if "jax" in frameworks:
install_reqs.extend(["jax", "flax>=0.7.1"])
diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt
index e590d8e92a..45806e7022 100644
--- a/tests/cpp/operator/CMakeLists.txt
+++ b/tests/cpp/operator/CMakeLists.txt
@@ -13,6 +13,7 @@ add_executable(test_operator
test_layernorm.cu
test_rmsnorm.cu
test_multi_cast_transpose.cu
+ test_multi_padding.cu
test_causal_softmax.cu
../test_common.cu)
diff --git a/tests/cpp/operator/test_cast_transpose.cu b/tests/cpp/operator/test_cast_transpose.cu
index 8c168c76f4..39a6614179 100644
--- a/tests/cpp/operator/test_cast_transpose.cu
+++ b/tests/cpp/operator/test_cast_transpose.cu
@@ -69,6 +69,8 @@ void performTest(const size_t N, const size_t H) {
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
+ float ref_scale_inv = 1.f / output_c.scale();
+ compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
diff --git a/tests/cpp/operator/test_cast_transpose_dbias.cu b/tests/cpp/operator/test_cast_transpose_dbias.cu
index a2c8594730..651508c871 100644
--- a/tests/cpp/operator/test_cast_transpose_dbias.cu
+++ b/tests/cpp/operator/test_cast_transpose_dbias.cu
@@ -116,6 +116,8 @@ void performTest(const size_t N, const size_t H) {
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
+ float ref_scale_inv = 1.f / output_c.scale();
+ compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
diff --git a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu
index 5920f23f38..38ac955bc9 100644
--- a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu
+++ b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu
@@ -132,6 +132,8 @@ void performTest(const size_t N, const size_t H) {
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
+ float ref_scale_inv = 1.f / output_c.scale();
+ compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
diff --git a/tests/cpp/operator/test_layernorm.cu b/tests/cpp/operator/test_layernorm.cu
index 07a8a8884c..cdd8e7846c 100644
--- a/tests/cpp/operator/test_layernorm.cu
+++ b/tests/cpp/operator/test_layernorm.cu
@@ -230,6 +230,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma)
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
if (isFp8Type(otype)) {
compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
+ float ref_scale_inv = 1.f / z.scale();
+ compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
diff --git a/tests/cpp/operator/test_multi_cast_transpose.cu b/tests/cpp/operator/test_multi_cast_transpose.cu
index 85fc3a573a..e7fb183217 100644
--- a/tests/cpp/operator/test_multi_cast_transpose.cu
+++ b/tests/cpp/operator/test_multi_cast_transpose.cu
@@ -139,6 +139,10 @@ void performTest() {
output_c_list[tensor_id].amax(),
ref_amax_list[tensor_id],
atol_amax, rtol_amax);
+ compareResults("scale_inv",
+ output_c_list[tensor_id].scale_inv(),
+ 1.f / output_c_list[tensor_id].scale(),
+ atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c",
diff --git a/tests/cpp/operator/test_multi_padding.cu b/tests/cpp/operator/test_multi_padding.cu
new file mode 100644
index 0000000000..e9e42725fe
--- /dev/null
+++ b/tests/cpp/operator/test_multi_padding.cu
@@ -0,0 +1,169 @@
+/*************************************************************************
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include <cmath>
+#include <cstring>
+#include <memory>
+#include <iomanip>
+#include <iostream>
+#include <random>
+#include <vector>
+
+#include <cuda_bf16.h>
+#include <cuda_runtime.h>
+#include <gtest/gtest.h>
+
+#include <transformer_engine/padding.h>
+#include "../test_common.h"
+
+using namespace transformer_engine;
+
+namespace {
+
+template <typename InputType, typename OutputType>
+void compute_ref(const std::vector<std::vector<InputType>>& input_list,
+ std::vector<std::vector<OutputType>>& output_list,
+ const std::vector<size_t>& height_list,
+ const std::vector<size_t>& width_list,
+ const std::vector<int>& padded_height_list) {
+ using compute_t = float;
+ for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
+ const auto& input = input_list[tensor_id];
+ auto& output = output_list[tensor_id];
+ const size_t height = height_list[tensor_id];
+ const size_t width = width_list[tensor_id];
+ const size_t padded_height = padded_height_list[tensor_id];
+
+ for (size_t i = 0; i < padded_height; ++i) {
+ if (i < height) {
+ for (size_t j = 0; j < width; ++j) {
+ const compute_t x = static_cast<compute_t>(input[i * width + j]);
+ const OutputType y = static_cast<OutputType>(x);
+ output[i * width + j] = y;
+ }
+ } else {
+ for (size_t j = 0; j < width; ++j) {
+ output[i * width + j] = static_cast<OutputType>(0.f);
+ }
+ }
+ }
+ }
+}
+
+template <typename InputType, typename OutputType>
+void performTest() {
+ using namespace test;
+
+ const DType itype = TypeInfo<InputType>::dtype;
+ const DType otype = TypeInfo<OutputType>::dtype;
+ const std::vector<std::pair<size_t, size_t>> tensor_dims = {{1,1},
+ {1,768},
+ {768,1},
+ {768,768},
+ {43,43},
+ {43,256},
+ {256,43},
+ {256,256}};
+ const size_t num_tensors = tensor_dims.size();
+ constexpr int align = 16;
+
+ // Buffers for Transformer Engine implementation
+ std::vector<Tensor> input_list, output_list, output_t_list;
+
+ // Buffers for reference implementation
+ std::vector<std::vector<InputType>> ref_input_list;
+ std::vector<std::vector<OutputType>> ref_output_list;
+ std::vector<size_t> ref_height_list(num_tensors), ref_width_list(num_tensors);
+ std::vector<int> ref_padded_height_list(num_tensors);
+
+ // Initialize buffers
+ for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
+ const size_t height = tensor_dims[tensor_id].first;
+ const size_t width = tensor_dims[tensor_id].second;
+ const size_t padded_height = (height + align - 1) / align * align;
+ input_list.emplace_back(Tensor({ height, width }, itype));
+ output_list.emplace_back(Tensor({ padded_height, width }, otype));
+
+ auto& input = input_list.back();
+ auto& output = output_list.back();
+ fillUniform(&input);
+ setRandomScale(&output);
+
+ ref_input_list.emplace_back(height*width);
+ ref_output_list.emplace_back(padded_height*width);
+
+ std::copy(input.cpu_dptr<InputType>(),
+ input.cpu_dptr<InputType>() + height * width,
+ ref_input_list.back().begin());
+ ref_height_list[tensor_id] = height;
+ ref_width_list[tensor_id] = width;
+ ref_padded_height_list[tensor_id] = padded_height;
+ }
+
+ // Transformer Engine implementation
+ auto make_nvte_vector = [](std::vector<Tensor>& tensor_list)
+ -> std::vector<NVTETensor> {
+ std::vector<NVTETensor> nvte_tensor_list;
+ for (auto& tensor : tensor_list) {
+ nvte_tensor_list.emplace_back(tensor.data());
+ }
+ return nvte_tensor_list;
+ };
+ nvte_multi_padding(num_tensors,
+ make_nvte_vector(input_list).data(),
+ make_nvte_vector(output_list).data(),
+ ref_padded_height_list.data(),
+ 0);
+ cudaDeviceSynchronize();
+ auto err = cudaGetLastError();
+ ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+
+ // Reference implementation
+ compute_ref(ref_input_list,
+ ref_output_list,
+ ref_height_list,
+ ref_width_list,
+ ref_padded_height_list);
+
+ // Check correctness
+ for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
+ auto [atol, rtol] = getTolerances(otype);
+ compareResults("output",
+ output_list[tensor_id],
+ ref_output_list[tensor_id].data(),
+ atol, rtol);
+ }
+}
+
+} // namespace
+
+class MultiPaddingTestSuite
+ : public ::testing::TestWithParam<
+ transformer_engine::DType> {};
+
+TEST_P(MultiPaddingTestSuite, TestMultiPaddingTranspose) {
+ using namespace transformer_engine;
+ using namespace test;
+
+ const DType input_type = GetParam();
+ const DType output_type = input_type;
+
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
+ performTest<InputType, OutputType>();
+ );
+ );
+}
+
+
+INSTANTIATE_TEST_SUITE_P(
+ OperatorTest,
+ MultiPaddingTestSuite,
+ ::testing::ValuesIn(test::all_fp_types),
+ [](const testing::TestParamInfo<transformer_engine::DType>& info) {
+ std::string name = test::typeName(info.param);
+ return name;
+ });
diff --git a/tests/cpp/operator/test_rmsnorm.cu b/tests/cpp/operator/test_rmsnorm.cu
index e4e34bac8a..0ec3a877e5 100644
--- a/tests/cpp/operator/test_rmsnorm.cu
+++ b/tests/cpp/operator/test_rmsnorm.cu
@@ -187,6 +187,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma)
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
if (isFp8Type(otype)) {
compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
+ float ref_scale_inv = 1.f / z.scale();
+ compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py
index 3a7fe33378..bbd54ecce5 100644
--- a/tests/jax/distributed_test_base.py
+++ b/tests/jax/distributed_test_base.py
@@ -4,6 +4,8 @@
import operator
import re
from functools import reduce
+from itertools import product
+import pytest
import jax
from jax.experimental.pjit import pjit, _UNSPECIFIED
@@ -29,6 +31,28 @@ def generate_configs():
return configs
+def generate_context_parallel_configs():
+ configs = []
+
+ DP_sizes = (1, 2)
+ CP_sizes = (1, 2, 4, 8)
+ TP_sizes = (1, 2)
+ for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes):
+ ndev = cp * tp * dp
+ if is_devices_enough(ndev):
+ configs.append(
+ pytest.param(
+ ndev,
+ (dp, cp, tp),
+ ("dp", "cp", "tp"),
+ MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp"),
+ id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}",
+ )
+ )
+
+ return configs
+
+
COLL_AR_KEY = "all-reduce"
COLL_AG_KEY = "all-gather"
COLL_OTHER_KEY = "other"
diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py
index 15676dd270..61d68aacae 100644
--- a/tests/jax/test_distributed_fused_attn.py
+++ b/tests/jax/test_distributed_fused_attn.py
@@ -3,6 +3,7 @@
# See LICENSE for license information.
import pytest
+from functools import partial
import jax
import jax.numpy as jnp
@@ -10,8 +11,13 @@
from flax.linen import dot_product_attention
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
-from distributed_test_base import generate_configs, generate_collectives_count, compare_ops
-from utils import make_causal_mask, make_self_mask
+from distributed_test_base import (
+ generate_configs,
+ generate_context_parallel_configs,
+ generate_collectives_count,
+ compare_ops,
+)
+from utils import make_causal_mask, make_self_mask, assert_tree_like_allclose, assert_allclose
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
@@ -19,6 +25,10 @@
AttnBiasType,
AttnMaskType,
QKVLayout,
+ QKVFormat,
+ get_qkv_format,
+ reorder_causal_load_balancing,
+ inverse_reorder_causal_load_balancing,
)
@@ -263,7 +273,8 @@ def target_func(q, kv, mask):
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
- )
+ ),
+ dtype=jnp.float32,
)
def ref_func(query, kv, mask):
@@ -284,7 +295,7 @@ def ref_func(query, kv, mask):
dtype=jnp.float32,
)
- return jnp.mean(output).astype(dtype)
+ return jnp.mean(output, dtype=jnp.float32)
(q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, attn_mask_type, dtype
@@ -310,3 +321,229 @@ def ref_func(query, kv, mask):
in_shardings=(q_pspec, kv_pspec, mask_pspec),
out_shardings=(None, (q_pspec, kv_pspec)),
)
+
+
+class TestDistributedContextParallelSelfAttn:
+
+ def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype):
+ batch, seqlen, heads, hidden = shape
+ qkey, kkey, vkey = random.split(random.PRNGKey(1124), 3)
+ q = random.normal(qkey, shape, dtype=dtype)
+ k = random.normal(kkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype)
+ v = random.normal(vkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype)
+
+ mask = None
+ if attn_mask_type == AttnMaskType.CAUSAL_MASK:
+ mask = make_causal_mask(batch, seqlen)
+
+ return q, k, v, mask
+
+ def qkv_to_layout(self, q, k, v, qkv_layout):
+ qkv_args = ()
+ match qkv_layout:
+ case QKVLayout.BSHD_BS2HD:
+ k, v = map(partial(jnp.expand_dims, axis=-3), [k, v])
+ kv = jnp.concatenate((k, v), axis=-3)
+ qkv_args = (q, kv)
+ case QKVLayout.BSHD_BSHD_BSHD:
+ qkv_args = (q, k, v)
+ case _:
+ raise ValueError(f"Unsupported {qkv_layout=}")
+ return qkv_args
+
+ @pytest.mark.parametrize(
+ "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
+ )
+ @pytest.mark.parametrize(
+ "data_shape",
+ [
+ pytest.param([2, 512, 12, 128], id="2-512-12-128"),
+ pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
+ ],
+ )
+ @pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
+ @pytest.mark.parametrize(
+ "attn_mask_type",
+ [
+ pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
+ pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
+ ],
+ )
+ @pytest.mark.parametrize("dtype", [jnp.bfloat16])
+ @pytest.mark.parametrize(
+ "qkv_layout",
+ [
+ pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
+ pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
+ ],
+ )
+ @pytest.mark.parametrize(
+ "load_balanced", [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")]
+ )
+ def test_context_parallel_self_attn(
+ self,
+ device_count,
+ mesh_shape,
+ mesh_axes,
+ mesh_resource,
+ data_shape,
+ kv_groups,
+ attn_mask_type,
+ dtype,
+ qkv_layout,
+ load_balanced,
+ ):
+ attn_bias_type = AttnBiasType.NO_BIAS
+ dropout_prob = 0.0
+ is_training = True
+ scaling_factor = 1.0
+ dp_size, cp_size, tp_size = mesh_shape
+ qkv_format = get_qkv_format(qkv_layout)
+
+ _, seqlen, num_head, hidden = data_shape
+ num_kv_heads = num_head // kv_groups
+
+ # Make sure the attention heads divide evenly across kv_groups and the TP axis
+ if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
+ pytest.skip(f"Skipping: {num_head=} is not divisible by {kv_groups=}, or {num_kv_heads=} by {tp_size=}")
+
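+ # Both the target and reference functions reduce the attention output to a scalar mean so
+ # jax.value_and_grad can return gradients with respect to q, k, and v.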
+ def target_func(q, k, v, mask):
+ return jnp.mean(
+ fused_attn(
+ self.qkv_to_layout(q, k, v, qkv_layout),
+ bias=None,
+ mask=mask,
+ seed=None,
+ attn_bias_type=attn_bias_type,
+ attn_mask_type=attn_mask_type,
+ qkv_layout=qkv_layout,
+ scaling_factor=scaling_factor,
+ dropout_probability=dropout_prob,
+ is_training=is_training,
+ context_parallel_causal_load_balanced=load_balanced,
+ ),
+ ).astype(dtype)
+
+ def ref_func(q, k, v, mask, kv_groups):
+ q = jnp.squeeze(q)
+ k = jnp.squeeze(jnp.repeat(k, kv_groups, axis=2))
+ v = jnp.squeeze(jnp.repeat(v, kv_groups, axis=2))
+ output = dot_product_attention(
+ q,
+ k,
+ v,
+ bias=None,
+ mask=mask,
+ deterministic=is_training,
+ dropout_rate=dropout_prob,
+ dropout_rng=None,
+ dtype=jnp.float32,
+ )
+ return jnp.mean(output).astype(dtype)
+
+ q, k, v, mask = self.generate_inputs(data_shape, kv_groups, attn_mask_type, dtype)
+
+ # Single GPU (reference)
+ ref_func_jit = jax.jit(jax.value_and_grad(ref_func, argnums=[0, 1, 2]), static_argnums=[4])
+ ref_fwd, ref_grads = ref_func_jit(q, k, v, mask, kv_groups)
+
+ # Multi GPU (function under test)
+ devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
+ mesh = Mesh(devices, mesh_axes)
+ with mesh, fp8_autocast(mesh_resource=mesh_resource):
+ qkv_ps = PartitionSpec(
+ mesh_resource.dp_resource,
+ mesh_resource.cp_resource,
+ mesh_resource.tp_resource,
+ None,
+ )
+ qkv_sharding = NamedSharding(mesh, qkv_ps)
+
+ mask_ps = PartitionSpec(
+ mesh_resource.dp_resource, None, mesh_resource.cp_resource, None
+ )
+ mask_sharding = NamedSharding(mesh, mask_ps)
+
+ reorder = partial(
+ reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format
+ )
+ inverse_reorder = partial(
+ inverse_reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format
+ )
+
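+ # With load balancing, tokens are reordered along the sequence axis before sharding so each CP rank
+ # receives a comparable share of the causal-mask work; the gradients are mapped back with the
+ # inverse reorder below before comparing against the unsharded reference.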
+ if load_balanced:
+ q, k, v = jax.tree.map(reorder, (q, k, v))
+
+ q_, k_, v_ = map(partial(jax.device_put, device=qkv_sharding), [q, k, v])
+ mask_ = jax.device_put(mask, device=mask_sharding)
+
+ target_func_jit = jax.jit(
+ jax.value_and_grad(target_func, argnums=[0, 1, 2]),
+ in_shardings=[qkv_sharding, qkv_sharding, qkv_sharding, mask_sharding],
+ out_shardings=(None, (qkv_sharding, qkv_sharding, qkv_sharding)),
+ )
+
+ target_fwd, target_grads = target_func_jit(q_, k_, v_, mask_)
+
+ if load_balanced:
+ target_dq, target_dk, target_dv = jax.tree.map(inverse_reorder, target_grads[0:3])
+ target_grads = (target_dq, target_dk, target_dv, *target_grads[3:])
+
+ def _print_diffs(target, ref):
+ print("min: ", jnp.min(target), jnp.min(ref))
+ print("max: ", jnp.max(target), jnp.max(ref))
+ print("mean: ", jnp.mean(target), jnp.mean(ref))
+ print("median: ", jnp.median(target), jnp.median(ref))
+ print("std: ", jnp.std(target), jnp.std(ref))
+ print("var: ", jnp.var(target), jnp.var(ref))
+ print("max diff: ", jnp.max(jnp.abs(target - ref)))
+
+ has_diffs = False
+
+ try:
+ assert_allclose(target_fwd, ref_fwd, dtype=dtype)
+ except AssertionError as e:
+ has_diffs = True
+ print(f"target_fwd v. ref_fwd")
+ _print_diffs(target_fwd, ref_fwd)
+
+ for i in range(len(target_grads)):
+ if ref_grads[i] is None or target_grads[i] is None:
+ # if one gradient is None, the other must be None as well
+ assert target_grads[i] is None and ref_grads[i] is None
+ else:
+ try:
+ assert_allclose(target_grads[i], ref_grads[i])
+ except AssertionError as e:
+ has_diffs = True
+ print(f"target_grads[{i}] v. ref_grads[{i}]")
+ _print_diffs(target_grads[i], ref_grads[i])
+
+ assert not has_diffs, "Mismatches found between context-parallel and reference results"
+
+
+class TestReorderCausalLoadBalancing:
+ @pytest.mark.parametrize("cp_size", [2, 4, 8])
+ @pytest.mark.parametrize(
+ "shape",
+ [
+ pytest.param([1, 16, 1, 1], id="1-16-1-1"),
+ pytest.param([4, 32, 12, 32], id="4-32-12-32"),
+ pytest.param([3, 32, 8, 64], id="3-32-8-64"),
+ ],
+ )
+ @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
+ def test(self, cp_size, shape, qkv_format):
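+ # Applying the inverse reorder to the reordered tensor must reproduce the original tensor exactly.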
+ tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
+ if qkv_format == QKVFormat.SBHD:
+ tensor = tensor.swapaxes(0, 1)
+
+ ref = tensor.copy()
+
+ reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2])
+ inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2])
+
+ reordered = reorder(tensor, cp_size, qkv_format)
+ inversed = inverse(reordered, cp_size, qkv_format)
+
+ assert jnp.array_equal(inversed, ref)
diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py
index 390a3e2c4e..2709eeab30 100644
--- a/tests/jax/test_fused_attn.py
+++ b/tests/jax/test_fused_attn.py
@@ -29,7 +29,10 @@
get_qkv_format,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
-from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend
+from transformer_engine.transformer_engine_jax import (
+ NVTE_Fused_Attn_Backend,
+ get_cudnn_version,
+)
from utils import assert_allclose
@@ -230,7 +233,14 @@ def customcall_fused_dpa(
kwargs.pop("max_segments_per_seq")
return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
return fused_attn_thd(
- qkv_args, bias, seqlens_q, seqlens_kv, offsets_q, offsets_kv, dropout_rng, **kwargs
+ qkv_args,
+ bias,
+ seqlens_q,
+ seqlens_kv,
+ offsets_q,
+ offsets_kv,
+ dropout_rng,
+ **kwargs,
).astype(query.dtype)
@@ -265,6 +275,15 @@ class FusedAttnRunner:
qkv_layout: QKVLayout
bias_shape: BiasShape
+ # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for a known issue
+ # with generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
+ def _get_max_segments_per_sequence(self):
+ if 90400 <= get_cudnn_version() < 90500:
+ return self.num_segments_per_seq
+ else:
+ # +1 for testing runtime_segments < max_segments
+ return self.num_segments_per_seq + 1
+
def _check_configs(self):
# TODO(rewang): probably adds this in is_fused_attn_available
if get_qkv_format(self.qkv_layout) == QKVFormat.THD and not self.attn_mask_type in [
@@ -299,7 +318,10 @@ def _check_configs(self):
self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
and self.bias_shape != BiasShape.BIAS_1HSS
):
- if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
+ if self.attn_mask_type not in [
+ AttnMaskType.NO_MASK,
+ AttnMaskType.CAUSAL_MASK,
+ ]:
pytest.skip(
"B1SS, BHSS and 11SS bias shapes are only supported for "
"AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK."
@@ -316,7 +338,12 @@ def _setup_inputs(self):
q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
- k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim)
+ k_shape = v_shape = (
+ self.batch_size,
+ self.max_seqlen_kv,
+ self.num_heads_kv,
+ self.head_dim,
+ )
if self.attn_bias_type == AttnBiasType.NO_BIAS:
bias_shape = None
@@ -325,7 +352,12 @@ def _setup_inputs(self):
elif self.bias_shape == BiasShape.BIAS_B1SS:
bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
elif self.bias_shape == BiasShape.BIAS_BHSS:
- bias_shape = (self.batch_size, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
+ bias_shape = (
+ self.batch_size,
+ self.num_heads_q,
+ self.max_seqlen_q,
+ self.max_seqlen_kv,
+ )
elif self.bias_shape == BiasShape.BIAS_11SS:
bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
else:
@@ -405,7 +437,10 @@ def generate_random_segment_ids(
self.segment_pad_kv = self.segment_pad_q
else:
self.token_kv, self.segment_pad_kv = generate_random_segment_ids(
- self.batch_size, self.max_seqlen_kv, self.num_segments_per_seq, seed=2024
+ self.batch_size,
+ self.max_seqlen_kv,
+ self.num_segments_per_seq,
+ seed=2024,
)
self.pad_q = self.segment_pad_q
self.pad_kv = self.segment_pad_kv
@@ -464,8 +499,7 @@ def test_forward(self):
"dropout_probability": self.dropout_prob,
"is_training": self.is_training,
"qkv_layout": self.qkv_layout,
- # +1 for testing runtime_segments < max_segments
- "max_segments_per_seq": self.num_segments_per_seq + 1,
+ "max_segments_per_seq": self._get_max_segments_per_sequence(),
}
# Convert the outputs to float32 for the elementwise comparison
@@ -522,7 +556,7 @@ def grad_func(func, *args, **kwargs):
"dropout_probability": self.dropout_prob,
"is_training": self.is_training,
"qkv_layout": self.qkv_layout,
- "max_segments_per_seq": self.num_segments_per_seq + 1,
+ "max_segments_per_seq": self._get_max_segments_per_sequence(),
}
# We can compute dBias only for the [1, h, s, s] layout
@@ -635,7 +669,16 @@ def check_dqkv(primitive, reference, pad):
pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
pytest.param(4, 512, 128, 16, 16, 64, jnp.bfloat16, id="4-512-128-16-16-64-BF16-CROSS"),
- pytest.param(2, 2048, 1024, 12, 12, 64, jnp.bfloat16, id="2-2048-1048-12-12-64-BF16-CROSS"),
+ pytest.param(
+ 2,
+ 2048,
+ 1024,
+ 12,
+ 12,
+ 64,
+ jnp.bfloat16,
+ id="2-2048-1048-12-12-64-BF16-CROSS",
+ ),
pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"),
pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
],
diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py
index 82a3c8576b..d110dece53 100644
--- a/tests/pytorch/fused_attn/test_fused_attn.py
+++ b/tests/pytorch/fused_attn/test_fused_attn.py
@@ -22,6 +22,7 @@
get_attention_backend,
_flash_attn_2_plus,
_flash_attn_2_3_plus,
+ _flash_attn_3_plus,
check_set_window_size,
AttentionParams,
_attention_backends,
@@ -135,7 +136,6 @@ def _get_attention_backends(
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
- global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
alibi_slopes_shape = None
@@ -233,9 +233,9 @@ def test_dot_product_attention(
"""Test DotProductAttention module"""
# Get configs
- tols = dict(atol=5e-3, rtol=5e-3)
+ tols = dict(atol=1e-3, rtol=1e-3)
if dtype == torch.bfloat16:
- tols = dict(atol=2.5e-2, rtol=2.5e-2)
+ tols = dict(atol=1.5e-2, rtol=1.5e-2)
config = model_configs[model]
is_mla = config.head_dim_qk != config.head_dim_v
if qkv_layout is None:
@@ -420,6 +420,10 @@ def test_dpa_mla(dtype, model_configs, model):
"mask_8_1": ModelConfig(
1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
),
+ "mask_9_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"),
+ "mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"),
+ "mask_10_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
+ "mask_10_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
}
@@ -674,7 +678,6 @@ def _run_dot_product_attention(
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
- global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
# Create seqlens
@@ -1032,7 +1035,7 @@ def test_transformer_layer(
# Get configs
config = model_configs[model]
- tols = dict(atol=5e-1, rtol=5e-2)
+ tols = dict(atol=5e-2, rtol=5e-2)
workspace_opt = True
# Test backend availability
@@ -1163,7 +1166,6 @@ def _run_transformer_layer(
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
- global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
# Create input tensor
@@ -1301,6 +1303,7 @@ def _run_transformer_layer(
"fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_13": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"),
"fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),
+ "fp8_15": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
}
param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
@@ -1312,6 +1315,27 @@ def _rmse(a, b):
return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum())
+def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
+ logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item()))
+ logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item()))
+ try:
+ torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
+ except Exception as e:
+ logging.debug(e)
+
+ rmse = _rmse(a, b)
+ logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse))
+ rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
+ assert rmse < rmse_tol * rmse_range, (
+ name_a
+ + " vs "
+ + name_b
+ + " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
+ rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
+ )
+ )
+
+
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@@ -1320,86 +1344,76 @@ def _rmse(a, b):
@pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
@pytest.mark.parametrize("input_layernorm", [True, False])
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
-def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd):
- os.environ["NVTE_FLASH_ATTN"] = "0"
- os.environ["NVTE_FUSED_ATTN"] = "1"
+@pytest.mark.parametrize("RoPE", [True, False])
+@pytest.mark.parametrize("is_training", [True, False])
+def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training):
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
- global _attention_backends
- _attention_backends["backend_selection_requires_update"] = True
+ os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
config = model_configs_fp8_vs_f16[model]
- os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
+ if _flash_attn_3_plus and not is_training:
+ if RoPE:
+ pytest.skip("Flash Attention doesn't support FP8 MHA with RoPE.")
+ os.environ["NVTE_FLASH_ATTN"] = "1"
+ os.environ["NVTE_FUSED_ATTN"] = "0"
+ _attention_backends["backend_selection_requires_update"] = True
+ logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
+ flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
+ dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
+ )
+ os.environ["NVTE_FLASH_ATTN"] = "0"
+ os.environ["NVTE_FUSED_ATTN"] = "1"
+ _attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
- dtype, config, True, qkv_format, input_layernorm
+ dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
)
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
- dtype, config, False, qkv_format, input_layernorm
- )
-
- tols = dict(atol=5e-1, rtol=5e-1)
- rmse_tol = 0.1
- fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)
- fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min(
- fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item()
+ dtype, config, False, qkv_format, input_layernorm, RoPE, is_training
)
+ atol = 5e-1
+ rtol = 5e-1
+ rmse_tol = 0.15
logging.debug("========== {:^25s} ==========".format("forward output"))
- logging.debug(
- "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
- fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
+ if _flash_attn_3_plus and not is_training:
+ _error(
+ flash_attn_fwd_fp8,
+ fused_attn_fwd_f16,
+ "flash_attn_fwd_fp8",
+ "fused_attn_fwd_f16",
+ atol,
+ rtol,
+ rmse_tol,
)
+ _error(
+ fused_attn_fwd_fp8,
+ fused_attn_fwd_f16,
+ "fused_attn_fwd_fp8",
+ "fused_attn_fwd_f16",
+ atol,
+ rtol,
+ rmse_tol,
)
- logging.debug(
- "fused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
- fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()
- )
- )
- logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse))
- try:
- torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
- except Exception as e:
- logging.debug(e)
- assert (
- fwd_rmse < rmse_tol * fwd_range
- ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
- fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
- )
- for i in range(len(param_names[:1])):
- bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])
- bwd_range = max(
- fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item()
- ) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item())
-
- logging.debug("========== {:^25s} ==========".format(param_names[i]))
- logging.debug(
- "fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format(
- i, fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()
- )
- )
- logging.debug(
- "fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format(
- i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()
+ if is_training:
+ for i in range(len(param_names[:1])):
+ logging.debug("========== {:^25s} ==========".format(param_names[i]))
+ _error(
+ fused_attn_bwd_fp8[i],
+ fused_attn_bwd_f16[i],
+ f"fused_attn_bwd_fp8[{i}]",
+ f"fused_attn_bwd_f16[{i}]",
+ atol,
+ rtol,
+ rmse_tol,
)
- )
- logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse))
- try:
- torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
- except Exception as e:
- logging.debug(e)
-
- assert (
- bwd_rmse < rmse_tol * bwd_range
- ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
- bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
- )
-def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
+def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training):
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
@@ -1418,6 +1432,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
)
with fp8_model_init(enabled=fp8_mha):
+ rotary_pos_emb = None
+ if RoPE:
+ PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
+ rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
mha = MultiheadAttention(
hidden_size=config.hidden_size,
num_attention_heads=config.num_heads,
@@ -1434,6 +1452,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
qkv_weight_interleaved=True,
qkv_format=qkv_format,
).to(dtype=dtype, device="cuda")
+ if not is_training:
+ mha = mha.eval()
seqlens_q = torch.full(
[config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
@@ -1464,7 +1484,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda")
hidden_states = tensor.view(*tensor.shape[:-2], -1)
- hidden_states.requires_grad = True
+ if is_training:
+ hidden_states.requires_grad = True
tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
out_grad = tensor.view(*tensor.shape[:-2], -1)
@@ -1475,8 +1496,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
checkpoint_core_attention=False,
core_attention_bias_type=config.attn_bias_type,
is_first_microbatch=None,
+ rotary_pos_emb=rotary_pos_emb,
)
- out.backward(out_grad)
+ if is_training:
+ out.backward(out_grad)
param_names = []
param_names.append("hidden_states.grad")
@@ -1487,7 +1510,9 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
param_names.append(name + ".grad")
params.append(param)
- return out, param_names, tuple(x.grad for x in params)
+ if is_training:
+ return out, param_names, tuple(x.grad for x in params)
+ return out, param_names, tuple(None for x in params)
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@@ -1497,7 +1522,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
-def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
+@pytest.mark.parametrize("is_training", [True, False])
+def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
config = model_configs_fp8_vs_f16[model]
if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
@@ -1505,76 +1531,68 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
- global _attention_backends
- _attention_backends["backend_selection_requires_update"] = True
+ if _flash_attn_3_plus and not is_training:
+ os.environ["NVTE_FLASH_ATTN"] = "1"
+ os.environ["NVTE_FUSED_ATTN"] = "0"
+ _attention_backends["backend_selection_requires_update"] = True
+ logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
+ flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
+ dtype, config, True, qkv_layout, is_training
+ )
+
+ os.environ["NVTE_FLASH_ATTN"] = "0"
+ os.environ["NVTE_FUSED_ATTN"] = "1"
+ _attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
- fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(dtype, config, True, qkv_layout)
+ fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
+ dtype, config, True, qkv_layout, is_training
+ )
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False")
- fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(dtype, config, False, qkv_layout)
+ fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
+ dtype, config, False, qkv_layout, is_training
+ )
- tols = dict(atol=5e-1, rtol=5e-2)
+ atol = 5e-1
+ rtol = 5e-2
rmse_tol = 0.1
bwd_names = ["dq", "dk", "dv"]
- fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)
- fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min(
- fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item()
- )
-
logging.debug("========== {:^25s} ==========".format("forward output"))
- logging.debug(
- "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
- fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
- )
- )
- logging.debug(
- "fused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
- fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()
+ if _flash_attn_3_plus and not is_training:
+ _error(
+ flash_attn_fwd_fp8,
+ fused_attn_fwd_f16,
+ "flash_attn_fwd_fp8",
+ "fused_attn_fwd_f16",
+ atol,
+ rtol,
+ rmse_tol,
)
+ _error(
+ fused_attn_fwd_fp8,
+ fused_attn_fwd_f16,
+ "fused_attn_fwd_fp8",
+ "fused_attn_fwd_f16",
+ atol,
+ rtol,
+ rmse_tol,
)
- logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse))
- try:
- torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
- except Exception as e:
- logging.debug(e)
-
- assert (
- fwd_rmse < rmse_tol * fwd_range
- ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
- fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
- )
- for i, _ in enumerate(fused_attn_bwd_f16):
- bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])
- bwd_range = max(
- fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item()
- ) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item())
-
- logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
- logging.debug(
- "fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format(
- i, fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()
- )
- )
- logging.debug(
- "fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format(
- i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()
+ if is_training:
+ for i, _ in enumerate(fused_attn_bwd_f16):
+ logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
+ _error(
+ fused_attn_bwd_fp8[i],
+ fused_attn_bwd_f16[i],
+ f"fused_attn_bwd_fp8[{i}]",
+ f"fused_attn_bwd_f16[{i}]",
+ atol,
+ rtol,
+ rmse_tol,
)
- )
- logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse))
- try:
- torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
- except Exception as e:
- logging.debug(e)
-
- assert (
- bwd_rmse < rmse_tol * bwd_range
- ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
- bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
- )
-def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
+def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
@@ -1607,6 +1625,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
attention_type="self",
qkv_format=qkv_format,
).to(dtype=dtype, device="cuda")
+ if not is_training:
+ dpa = dpa.eval()
seqlens_q = torch.full(
[config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
@@ -1680,9 +1700,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
core_attention_bias_type=config.attn_bias_type,
is_first_microbatch=True,
)
- out.backward(out_grad)
+ if is_training:
+ out.backward(out_grad)
- return out, (inp[0].grad, inp[1].grad, inp[2].grad)
+ if is_training:
+ return out, (inp[0].grad, inp[1].grad, inp[2].grad)
+ return out, (None, None, None)
model_configs_fp8 = {
@@ -1726,58 +1749,26 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
- tols = dict(atol=5e-1, rtol=5e-1)
+ atol = 5e-1
+ rtol = 5e-1
rmse_tol = 0.1
- fwd_rmse = _rmse(fused_attn_fwd_fp8, unfused_attn_fwd_f16)
- fwd_range = max(fused_attn_fwd_fp8.max().item(), unfused_attn_fwd_f16.max().item()) - min(
- fused_attn_fwd_fp8.min().item(), unfused_attn_fwd_f16.min().item()
- )
- bwd_rmse = _rmse(fused_attn_bwd_fp8, unfused_attn_bwd_f16)
- bwd_range = max(fused_attn_bwd_fp8.max().item(), unfused_attn_bwd_f16.max().item()) - min(
- fused_attn_bwd_fp8.min().item(), unfused_attn_bwd_f16.min().item()
- )
-
- logging.debug(
- "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
- fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
- )
- )
- logging.debug(
- "unfused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
- unfused_attn_fwd_f16.min().item(), unfused_attn_fwd_f16.max().item()
- )
- )
- logging.debug("fused_attn_fwd_fp8 vs unfused_attn_fwd_f16 RMSE: {:.6f}".format(fwd_rmse))
- try:
- torch.testing.assert_close(fused_attn_fwd_fp8, unfused_attn_fwd_f16, **tols)
- except Exception as e:
- logging.debug(e)
-
- logging.debug(
- "fused_attn_bwd_fp8 min {:.6f} max {:.6f}".format(
- fused_attn_bwd_fp8.min().item(), fused_attn_bwd_fp8.max().item()
- )
- )
- logging.debug(
- "unfused_attn_bwd_f16 min {:.6f} max {:.6f}".format(
- unfused_attn_bwd_f16.min().item(), unfused_attn_bwd_f16.max().item()
- )
- )
- logging.debug("fused_attn_bwd_fp8 vs unfused_attn_bwd_f16 RMSE: {:.6f}".format(bwd_rmse))
- try:
- torch.testing.assert_close(fused_attn_bwd_fp8, unfused_attn_bwd_f16, **tols)
- except Exception as e:
- logging.debug(e)
-
- assert (
- fwd_rmse < rmse_tol * fwd_range
- ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
- fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
+ _error(
+ fused_attn_fwd_fp8,
+ unfused_attn_fwd_f16,
+ "fused_attn_fwd_fp8",
+ "unfused_attn_fwd_f16",
+ atol,
+ rtol,
+ rmse_tol,
)
- assert (
- bwd_rmse < rmse_tol * bwd_range
- ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
- bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
+ _error(
+ fused_attn_bwd_fp8,
+ unfused_attn_bwd_f16,
+ "fused_attn_bwd_fp8",
+ "unfused_attn_bwd_f16",
+ atol,
+ rtol,
+ rmse_tol,
)
@@ -1791,7 +1782,6 @@ def _run_custom_mha_fp8(dtype, config, backend):
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
- global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
inp = 0.0001 * torch.randint(
@@ -1846,7 +1836,6 @@ def _run_ref_mha_f16(dtype, config, backend):
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
- global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
inp = torch.load("qkv.pt").to(device="cuda")
@@ -1996,12 +1985,18 @@ def forward(
None,
None,
None,
- fp8_meta["scaling_fwd"].scale_inv[META_QKV],
- fp8_meta["scaling_fwd"].scale_inv[META_S],
- fp8_meta["scaling_fwd"].scale[META_S],
- fp8_meta["scaling_fwd"].scale[META_O],
- fp8_meta["scaling_fwd"].amax_history[0][META_S],
- fp8_meta["scaling_fwd"].amax_history[0][META_O],
+ fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv
+ META_QKV, # d_scale_qkv_offset
+ fp8_meta["scaling_fwd"].scale_inv, # d_scale_s
+ META_S, # d_scale_s_offset
+ fp8_meta["scaling_fwd"].scale, # q_scale_s
+ META_S, # q_scale_s_offset
+ fp8_meta["scaling_fwd"].scale, # q_scale_o
+ META_O, # q_scale_o_offset
+ fp8_meta["scaling_fwd"].amax_history, # amax_s
+ META_S, # amax_s_offset
+ fp8_meta["scaling_fwd"].amax_history, # amax_o
+ META_O, # amax_o_offset
attn_scale=None,
dropout=p_dropout,
fast_zero_fill=fast_zero_fill,
diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
index 82875e2791..c1c18ffe47 100644
--- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
+++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -22,10 +22,16 @@
"cp_1_2": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # MHA
- "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
- "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
+ "cp_1_3": ModelConfig(
+ 2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
+ ), # MHA
+ "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
+ "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(
- 2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
+ 2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
+ ), # GQA
+ "cp_2_3": ModelConfig(
+ 2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
), # GQA
}
@@ -45,31 +51,32 @@ def get_bash_arguments(**kwargs):
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
-@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
+@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"])
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
config = model_configs_flash_attn[model]
+ if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
+ pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
- pytest.skip(
- f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
- )
- if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
- pytest.skip(
- f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
- " type yet!"
- )
+ pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
+ pytest.skip("CP implementation with KV all-gather does not support bias yet!")
+ if cp_comm_type == "a2a" and qkv_format == "thd":
+ pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
+ if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias":
+ pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
+ if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
pytest.skip(
- f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
- " type yet!"
- )
- if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
- pytest.skip(
- f"CP implementation with KV P2P does not support window size {config.window_size} yet!"
+ f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
+ f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
)
subprocess.run(
get_bash_arguments(
- dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention"
+ dtype=dtype,
+ model=model,
+ qkv_format=qkv_format,
+ kernel_backend="FlashAttention",
+ cp_comm_type=cp_comm_type,
),
check=True,
)
@@ -81,10 +88,16 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
"cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA
- "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
- "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
- "cp_2_2": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
- "cp_2_3": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
+ "cp_1_4": ModelConfig(
+ 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
+ ), # MHA
+ "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
+ "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
+ "cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
+ "cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
+ "cp_2_4": ModelConfig(
+ 2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
+ ), # GQA
}
@@ -93,37 +106,29 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
-@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
+@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"])
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
- pytest.skip("THD format is only supported on sm90+.")
+ pytest.skip("THD format is only supported on sm90+!")
if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0):
- pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0")
+ pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")
+ if dtype == "fp8" and get_device_compute_capability() < (9, 0):
+ pytest.skip("FP8 attention is only supported on sm90+!")
config = model_configs_fused_attn[model]
if qkv_format == "thd" and config.num_heads != config.num_gqa_groups:
- pytest.skip(f"{qkv_format} format does not support QGA/MQA yet!")
+ pytest.skip("THD format does not support QGA/MQA yet!")
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
- pytest.skip(f"{qkv_format} format does not support {config.attn_bias_type} bias type yet!")
- if cp_comm_type == "all_gather" and qkv_format == "thd":
- pytest.skip(
- f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
- )
- if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
- pytest.skip(
- f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
- " type yet!"
- )
- if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
- pytest.skip(
- f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
- " type yet!"
- )
- if config.window_size != (-1, 0) and config.window_size != (-1, -1):
+ pytest.skip("THD format does not support post_scale_bias yet!")
+ if qkv_format == "thd" and cp_comm_type == "all_gather":
+ pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
+ if qkv_format == "thd" and cp_comm_type == "a2a":
+ pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
+ if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a":
pytest.skip(
- "Fused attention does not support sliding window attention + context parallelism yet!"
+ "Sliding window attention only can be supported with the implementation of QKVO A2A!"
)
- if cp_comm_type == "all_gather" and dtype == "fp8":
+ if dtype == "fp8" and cp_comm_type == "all_gather":
pytest.skip(
"CP implementation with KV all-gather does not support FP8 + context parallelism yet!"
)
@@ -131,10 +136,25 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
pytest.skip("FP8 attention cannot work with THD format yet!")
if dtype == "fp8" and config.attn_bias_type != "no_bias":
pytest.skip("FP8 attention cannot work with bias yet!")
+ if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
+ pytest.skip("FP8 attention cannot work with sliding window yet!")
+ if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
+ pytest.skip("CP implementation with KV all-gather does not support bias yet!")
+ if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias":
+ pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
+ if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
+ pytest.skip(
+ f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
+ f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
+ )
subprocess.run(
get_bash_arguments(
- dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention"
+ dtype=dtype,
+ model=model,
+ qkv_format=qkv_format,
+ kernel_backend="FusedAttention",
+ cp_comm_type=cp_comm_type,
),
check=True,
)
diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py
index 0ea0319771..fd204f58c4 100644
--- a/tests/pytorch/test_float8tensor.py
+++ b/tests/pytorch/test_float8tensor.py
@@ -293,7 +293,7 @@ def test_transpose(
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8_t, x, **tols)
- # Caching test.
+ # Caching test
assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching."
x_fp8 += 0.5
x = x_fp8.from_float8()
@@ -302,14 +302,13 @@ def test_transpose(
torch.testing.assert_close(x_fp8_t, x_t, **tols)
assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
- # Inplace update test.
+ # Inplace update test
x_fp8 += 0.5
- assert x_fp8._transpose_invalid, "Transpose cache not invalidated properly."
+ assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
x = x_fp8.from_float8()
- x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(fill_cache=True))
+ x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8._transpose)
x_t = x.transpose(0, 1)
torch.testing.assert_close(x_fp8_t, x_t, **tols)
- assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
def test_serialization(
self,
diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py
index 3523e1cda5..e97dfe1efd 100644
--- a/tests/pytorch/test_fusible_ops.py
+++ b/tests/pytorch/test_fusible_ops.py
@@ -88,10 +88,7 @@ def make_reference_and_test_tensors(
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8:
- test = Float8Tensor.to_float8(test)
- test._transpose = test._data.reshape(-1, test.size(-1)).transpose(0, 1)
- test._transpose = test._transpose.contiguous()
- test._transpose_invalid = False
+ test = Float8Tensor.to_float8(test, with_transpose_cache=True)
elif test.data_ptr() == ref.data_ptr():
test = test.clone()
ref.copy_(test)
diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index 85cd4fc256..ad34b4996f 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -7,6 +7,7 @@
from typing import Dict, List, Optional
import pytest
import copy
+import random
import torch
import torch.nn as nn
@@ -30,6 +31,8 @@
TransformerLayer,
LayerNorm,
InferenceParams,
+ Fp8Padding,
+ Fp8Unpadding,
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm
@@ -354,6 +357,40 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
return (input > 0) * input * input
+class TorchGroupedLinearWithPadding(nn.Module):
+
+ def __init__(
+ self, num_gemms, in_features, out_features, bias, params_dtype, parallel_mode, fp8
+ ) -> None:
+ super().__init__()
+
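+ # Pad each group's rows before the FP8 grouped GEMM and strip the padding afterwards,
+ # mirroring the manual padding done in the reference path of the accuracy test below.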
+ self.padding = Fp8Padding(num_gemms)
+ self.linear_fn = GroupedLinear(
+ num_gemms,
+ in_features,
+ out_features,
+ bias=bias,
+ params_dtype=params_dtype,
+ parallel_mode=parallel_mode,
+ device="cuda",
+ )
+ self.unpadding = Fp8Unpadding(num_gemms)
+
+ self.fp8 = fp8
+
+ def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor:
+ if self.fp8:
+ orig_m_splits = m_splits
+ inp, m_splits = self.padding(inp, m_splits)
+
+ out = self.linear_fn(inp, m_splits)
+
+ if self.fp8:
+ out = self.unpadding(out, orig_m_splits)
+
+ return out
+
+
_supported_act = {
"geglu": nn.GELU(approximate="tanh"),
"gelu": nn.GELU(approximate="tanh"),
@@ -1229,12 +1266,15 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False
)
inp_hidden_states.retain_grad()
- m = config.seq_len // 16
- dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
- dist.append(dist[-1]) # Manually add a zero
- m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
- m_splits = m_splits * 16
- assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
+ if num_gemms > 1:
+ m = config.seq_len // 16
+ dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
+ dist.append(dist[-1]) # Manually add a zero
+ m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
+ m_splits = m_splits * 16
+ assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
+ else:
+ m_splits = torch.tensor([config.seq_len])
with fp8_autocast(enabled=fp8):
if isinstance(block, GroupedLinear):
@@ -1316,7 +1356,7 @@ def test_grouped_linear_accuracy(
@pytest.mark.parametrize("parallel_mode", ["column", "row"])
def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
- """Split the tests to reduce CI time"""
+ """Split the tests to save CI time"""
test_grouped_linear_accuracy(
dtype=torch.float32,
num_gemms=6,
@@ -1328,6 +1368,170 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
)
+def test_grouped_linear_accuracy_single_gemm():
+ """Split the tests to save CI time"""
+ test_grouped_linear_accuracy(
+ dtype=torch.float32,
+ num_gemms=1,
+ bs=2,
+ model=list(model_configs.keys())[0],
+ fp8=True,
+ fp8_model_params=True,
+ )
+
+
+def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False):
+
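+ # The module under test pads/unpads via Fp8Padding/Fp8Unpadding, while the reference GroupedLinear
+ # path below pads each group to a multiple of 16 manually; the two paths must match bit-wise.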
+ def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
+ """Padding tensor shapes to multiples of 16."""
+ padded_tokens_per_expert = [
+ (num_tokens + 15) // 16 * 16 for num_tokens in tokens_per_expert
+ ]
+ hidden_states = torch.split(hidden_states, tokens_per_expert)
+ padded_hidden_states = []
+ for hidden_state, actual_num_tokens, padded_num_tokens in zip(
+ hidden_states, tokens_per_expert, padded_tokens_per_expert
+ ):
+ padded_hidden_states.append(hidden_state)
+ if padded_num_tokens > actual_num_tokens:
+ pad_tensor = torch.zeros(
+ padded_num_tokens - actual_num_tokens,
+ hidden_state.shape[1],
+ dtype=hidden_state.dtype,
+ device=hidden_state.device,
+ )
+ padded_hidden_states.append(pad_tensor)
+ padded_hidden_states = torch.cat(padded_hidden_states, dim=0)
+ return padded_hidden_states, padded_tokens_per_expert
+
+ def _unpad_tensor_for_fp8(padded_hidden_states, actual_tokens_per_expert, tokens_per_expert):
+ inputmats = torch.split(
+ padded_hidden_states.view(-1, padded_hidden_states.shape[-1]), tokens_per_expert
+ )
+ hidden_states = torch.cat(
+ [
+ grad_output_mat[: actual_tokens_per_expert[i]]
+ for i, grad_output_mat in enumerate(inputmats)
+ ],
+ dim=0,
+ )
+
+ return hidden_states
+
+ def _generate_random_numbers(n, total_sum):
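+ # Split total_sum into n positive random parts by sampling n - 1 distinct cut points.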
+ if n <= 0:
+ return []
+
+ # reset seed
+ random.seed(seed)
+
+ breaks = sorted(random.sample(range(1, total_sum), n - 1))
+ random_numbers = (
+ [breaks[0]]
+ + [breaks[i] - breaks[i - 1] for i in range(1, n - 1)]
+ + [total_sum - breaks[-1]]
+ )
+
+ return random_numbers
+
+ reset_rng_states()
+ if fp8:
+ FP8GlobalStateManager.reset()
+
+ inp_hidden_states = torch.randn(
+ (config.seq_len * bs, config.hidden_size),
+ dtype=dtype,
+ device="cuda",
+ requires_grad=True,
+ )
+ inp_hidden_states.retain_grad()
+
+ m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs)
+
+ with fp8_autocast(enabled=fp8):
+ if isinstance(block, TorchGroupedLinearWithPadding):
+ out = block(inp_hidden_states, m_splits)
+ else:
+ if fp8:
+ padded_inp_hidden_states, padding_m_splits = _pad_tensor_for_fp8(
+ inp_hidden_states, m_splits
+ )
+ padded_inp_hidden_states = block(padded_inp_hidden_states, padding_m_splits)
+ out = _unpad_tensor_for_fp8(padded_inp_hidden_states, m_splits, padding_m_splits)
+ else:
+ out = block(inp_hidden_states, m_splits)
+
+ loss = out.sum()
+ loss.backward()
+
+ torch.cuda.synchronize()
+ outputs = [out, inp_hidden_states.grad]
+ for p in block.parameters():
+ if p.requires_grad:
+ outputs.append(p.grad)
+ return outputs
+
+
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("num_gemms", [3, 6])
+@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("fp8", [True])
+@pytest.mark.parametrize("fp8_model_params", all_boolean)
+def test_padding_grouped_linear_accuracy(
+ dtype, num_gemms, bs, model, fp8, fp8_model_params, parallel_mode=None
+):
+ if fp8 and not fp8_available:
+ pytest.skip(reason_for_no_fp8)
+
+ config = model_configs[model]
+ if config.seq_len % 16 != 0 and fp8:
+ pytest.skip("FP8 requires sequence length to be divisible by 16.")
+
+ with fp8_model_init(enabled=fp8 and fp8_model_params):
+ grouped_linear = TorchGroupedLinearWithPadding(
+ num_gemms,
+ config.hidden_size,
+ 4 * config.hidden_size,
+ bias=False,
+ params_dtype=dtype,
+ parallel_mode=parallel_mode,
+ fp8=fp8,
+ ).eval()
+
+ with fp8_model_init(enabled=fp8 and fp8_model_params):
+ ref_grouped_linear = GroupedLinear(
+ num_gemms,
+ config.hidden_size,
+ 4 * config.hidden_size,
+ bias=False,
+ params_dtype=dtype,
+ parallel_mode=parallel_mode,
+ device="cuda",
+ ).eval()
+
+ # Share params
+ with torch.no_grad():
+ inner_grouped_linear = grouped_linear.linear_fn
+ for i in range(num_gemms):
+ setattr(
+ ref_grouped_linear,
+ f"weight{i}",
+ Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()),
+ )
+
+ outputs = _test_padding_grouped_linear_accuracy(
+ grouped_linear, num_gemms, bs, dtype, config, fp8
+ )
+ outputs_ref = _test_padding_grouped_linear_accuracy(
+ ref_grouped_linear, num_gemms, bs, dtype, config, fp8
+ )
+
+ # Should be a bit-wise match
+ for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
+ torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
+
+
def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
reset_rng_states()
@@ -1845,7 +2049,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
fp8_grouped_gemm(
A_fp8,
- scale_inv,
+ [scale_inv],
0, # A_offset
tex.DType.kFloat8E4M3,
B_fp8,
diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py
new file mode 100644
index 0000000000..ed25b96955
--- /dev/null
+++ b/tests/pytorch/test_permutation.py
@@ -0,0 +1,513 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import torch
+import pytest
+from typing import Dict, List
+
+from transformer_engine.pytorch import moe_permute as te_permute, moe_unpermute as te_unpermute
+from transformer_engine.pytorch.utils import is_bf16_compatible
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.float8_tensor import Float8Tensor
+import transformer_engine_torch as tex
+
+
+seed = 1234
+torch.manual_seed(seed)
+torch.cuda.manual_seed(seed)
+
+
+def pytorch_permute(tokens, indices, num_out_tokens: int = None):
+ """
+ Permute the tokens based on the indices. Tokens with the same index are grouped together.
+ The input indices have shape [num_tokens] or [num_tokens, topk] and indicate which experts each token selected.
+
+ Args:
+ tokens: torch.Tensor
+ The input token tensor.
+ indices: torch.Tensor
+ The token to expert indices tensor, should have a shape of [num_tokens] or [num_tokens, topk].
+ num_out_tokens: int, optional
+ The effective number of output tokens. When a capacity factor is enabled, this should equal the number of tokens that are not dropped.
+ By default, set to None, meaning no tokens are dropped.
+
+ Returns:
+ torch.Tensor:
+ The permuted tensor.
+ torch.Tensor:
+ The sorted_indices corresponding permuted tensor.
+ """
+ if indices.dim() == 1:
+ topk = 1
+ else:
+ topk = indices.size(1)
+ flatten_indices = indices.view(-1)
+ sorted_indices = torch.argsort(flatten_indices, stable=True)
+ num_out_tokens = num_out_tokens if num_out_tokens is not None else flatten_indices.size(0)
+
+ permuted_tokens = tokens.index_select(0, sorted_indices[:num_out_tokens] // topk)
+ return permuted_tokens, sorted_indices
+
+
+def pytorch_unpermute(
+ permuted_tokens: torch.Tensor,
+ sorted_indices: torch.Tensor,
+ probs: torch.Tensor = None,
+):
+ """
+ Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the tokens with their
+ corresponding probabilities.
+
+ Args:
+ permuted_tokens: torch.Tensor
+ The tensor of permuted tokens to be unpermuted.
+ sorted_indices: torch.Tensor
+ The tensor of sorted indices used to unpermute the tokens.
+ probs: torch.Tensor, optional
+ The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will
+ be merged with their respective probabilities.
+
+ Returns:
+ torch.Tensor:
+ The unpermuted tokens, optionally merged with probabilities.
+ """
+
+ if probs is not None:
+ # Unpermute and merge the tokens with their probabilities
+ num_unpermuted_tokens = probs.numel()
+ topk = probs.size(1)
+ else:
+ # Unpermute the tokens without merge
+ num_unpermuted_tokens = sorted_indices.size(0)
+ topk = 1
+ unpermuted_tokens = torch.zeros(
+ [num_unpermuted_tokens, permuted_tokens.shape[-1]],
+ dtype=permuted_tokens.dtype,
+ device=permuted_tokens.device,
+ )
+
+ unpermuted_tokens.index_copy_(0, sorted_indices[: permuted_tokens.size(0)], permuted_tokens)
+ unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1))
+ if probs is not None:
+ unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1)
+ unpermuted_tokens = unpermuted_tokens.sum(dim=1)
+ return unpermuted_tokens
+
+
+def dtype_tols(te_dtype: tex.DType) -> Dict[str, float]:
+ """Estimated tolerances for a datatype
+
+ Based on tolerances for torch.testing.assert_close.
+
+ """
+ if te_dtype == tex.DType.kFloat32:
+ return dict(rtol=1.0e-6, atol=1.0e-6)
+ if te_dtype == tex.DType.kFloat16:
+ return dict(rtol=3.0e-3, atol=1.0e-5)
+ if te_dtype == tex.DType.kBFloat16:
+ return dict(rtol=2.0e-2, atol=1.0e-5)
+ if te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3:
+ return dict(rtol=2.0e-1, atol=1.0e-1)
+ raise ValueError(f"Unsuppored dtype ({te_dtype})")
+
+
+def _test_permutation(
+ te_dtype,
+ num_tokens,
+ num_expert,
+ hidden_size,
+ topK,
+ num_out_tokens,
+ with_probs,
+ BENCHMARK=False,
+):
+ if not with_probs and topK > 1:
+ pytest.skip("Only permutations with topK=1 and without probabilities are supported.")
+
+ if topK > num_expert:
+ pytest.skip("topK should be smaller than the number of experts.")
+
+ if num_out_tokens is None:
+ num_out_tokens = num_tokens * topK
+
+ print(
+ f"token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}"
+ )
+
+ fp8 = False
+ # Convert TE dtypes to PyTorch dtypes
+ if te_dtype == tex.DType.kFloat32:
+ dtype = torch.float32
+ elif te_dtype == tex.DType.kFloat16:
+ dtype = torch.float16
+ elif te_dtype == tex.DType.kBFloat16:
+ dtype = torch.bfloat16
+ elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3):
+ dtype = torch.uint8
+ fp8 = True
+ else:
+ pytest.skip("Invalid dtype.")
+
+ if fp8:
+ permute_fwd_input = torch.rand(
+ size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
+ )
+ permute_bwd_input = torch.rand(
+ size=(num_out_tokens, hidden_size), dtype=torch.float32, device="cuda"
+ )
+ unpermute_bwd_input = torch.rand(
+ size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
+ )
+
+ permute_fwd_input = Float8Tensor.to_float8(
+ permute_fwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0)
+ )
+ permute_bwd_input = Float8Tensor.to_float8(
+ permute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0)
+ )
+ unpermute_bwd_input = Float8Tensor.to_float8(
+ unpermute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0)
+ )
+
+ pytorch_permute_fwd_input = permute_fwd_input.from_float8(torch.float16)
+ pytorch_permute_bwd_input = permute_bwd_input.from_float8(torch.float16)
+ pytorch_unpermute_bwd_input = unpermute_bwd_input.from_float8(torch.float16)
+ else:
+ pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
+ pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
+ pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
+
+ pytorch_permute_fwd_input.requires_grad_(True)
+
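+ # Each token picks topK distinct experts at random; an empty indices tensor covers the zero-token edge case.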
+ if num_tokens > 0:
+ indices = torch.stack([torch.randperm(num_expert)[:topK] for _ in range(num_tokens)])
+ else:
+ indices = torch.empty((num_tokens, topK))
+ indices = indices.to(torch.int32).cuda()
+
+ probs = None
+ if with_probs:
+ probs = torch.rand(num_tokens, topK).cuda()
+ row_sums = probs.sum(dim=1, keepdim=True)
+ probs = probs / row_sums
+ probs.requires_grad_(True)
+
+ ###################################################################################################################################
+ #
+ # PyTorch Permutation
+ #
+ ###################################################################################################################################
+ pytorch_permute_output, sorted_indices = pytorch_permute(
+ pytorch_permute_fwd_input, indices, num_out_tokens
+ )
+ pytorch_permute_output.backward(pytorch_permute_bwd_input, retain_graph=True)
+
+ pytorch_unpermute_fwd_input = pytorch_permute_output.detach()
+ pytorch_unpermute_fwd_input.requires_grad_(True)
+
+ pytorch_unpermute_output = pytorch_unpermute(
+ pytorch_unpermute_fwd_input, sorted_indices, probs=probs
+ )
+ pytorch_unpermute_output.backward(pytorch_unpermute_bwd_input, retain_graph=True)
+
+ ###################################################################################################################################
+ #
+ # TE Permutation
+ #
+ ###################################################################################################################################
+ te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach()
+ te_permute_fwd_input.requires_grad_(True)
+ te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach()
+
+ te_permute_output, row_id_map = te_permute(te_permute_fwd_input, indices, num_out_tokens)
+ te_permute_output.backward(te_permute_bwd_input, retain_graph=True)
+
+ te_probs = None
+ if with_probs:
+ te_probs = probs.detach()
+ te_probs.requires_grad_(True)
+ te_unpermute_fwd_input = te_permute_output.detach()
+ te_unpermute_fwd_input.requires_grad_(True)
+ te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach()
+
+ te_unpermute_output = te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs)
+ te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True)
+
+ ###################################################################################################################################
+ #
+ # Results Check
+ #
+ ###################################################################################################################################
+ tols = dtype_tols(te_dtype)
+
+ if fp8:
+ te_permute_output_ = te_permute_output.from_float8(torch.float32)
+ te_permute_fwd_input_grad = te_permute_fwd_input.grad.from_float8(torch.float32)
+ te_unpermute_output_ = te_unpermute_output.from_float8(torch.float32)
+ te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.from_float8(torch.float32)
+ else:
+ te_permute_output_ = te_permute_output.float()
+ te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
+ te_unpermute_output_ = te_unpermute_output.float()
+ te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()
+
+ torch.testing.assert_close(
+ pytorch_permute_output.float(),
+ te_permute_output_,
+ msg=f"Mismatch in te_permute fwd",
+ )
+ torch.testing.assert_close(
+ pytorch_permute_fwd_input.grad.float(),
+ te_permute_fwd_input_grad,
+ msg=f"Mismatch in te_permute bwd",
+ **tols,
+ )
+ torch.testing.assert_close(
+ pytorch_unpermute_output.float(),
+ te_unpermute_output_,
+ msg=f"Mismatch in te_unpermute fwd",
+ **tols,
+ )
+ torch.testing.assert_close(
+ pytorch_unpermute_fwd_input.grad.float(),
+ te_unpermute_fwd_input_grad,
+ msg=f"Mismatch in te_unpermute bwd",
+ **tols,
+ )
+ if with_probs:
+ torch.testing.assert_close(
+ probs.grad.float(), te_probs.grad.float(), msg="Mismatch in te_unpermute bwd (probs)", **tols
+ )
+
+ if not pytorch_permute_fwd_input.numel():
+ print("Empty pytorch_permute_fwd_input activation test passed.")
+ return
+
+ ###################################################################################################################################
+ #
+ # Benchmark
+ #
+ ###################################################################################################################################
+ def backward_wrapper(
+ act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False
+ ):
+ # Set forward_input.grad to None to avoid grad accumulation.
+ if not accumulate_grad:
+ for i in forward_input:
+ i.grad = None
+ return act.backward(backward_input, retain_graph=retain_graph)
+
+ if BENCHMARK:
+ t1 = perf_test_cuda_kernel(
+ lambda: pytorch_permute(pytorch_permute_fwd_input, indices, num_out_tokens)
+ )
+ t2 = perf_test_cuda_kernel(
+ lambda: te_permute(te_permute_fwd_input, indices, num_out_tokens)
+ )
+ print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
+
+ t1 = perf_test_cuda_kernel(
+ lambda: backward_wrapper(
+ pytorch_permute_output,
+ pytorch_permute_bwd_input,
+ forward_input=[pytorch_permute_fwd_input],
+ retain_graph=True,
+ accumulate_grad=False,
+ )
+ )
+ t2 = perf_test_cuda_kernel(
+ lambda: backward_wrapper(
+ te_permute_output,
+ te_permute_bwd_input,
+ forward_input=[te_permute_fwd_input],
+ retain_graph=True,
+ accumulate_grad=False,
+ )
+ )
+ print(f"permute\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
+
+ t1 = perf_test_cuda_kernel(
+ lambda: pytorch_unpermute(pytorch_unpermute_fwd_input, sorted_indices, probs=probs)
+ )
+ t2 = perf_test_cuda_kernel(
+ lambda: te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs)
+ )
+ print(f"unpermute\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
+
+ t1 = perf_test_cuda_kernel(
+ lambda: backward_wrapper(
+ pytorch_unpermute_output,
+ pytorch_unpermute_bwd_input,
+ forward_input=(
+ [pytorch_unpermute_fwd_input, probs]
+ if with_probs
+ else [pytorch_unpermute_fwd_input]
+ ),
+ retain_graph=True,
+ accumulate_grad=False,
+ )
+ )
+ t2 = perf_test_cuda_kernel(
+ lambda: backward_wrapper(
+ te_unpermute_output,
+ te_unpermute_bwd_input,
+ forward_input=(
+ [te_unpermute_fwd_input, te_probs] if with_probs else [te_unpermute_fwd_input]
+ ),
+ retain_graph=True,
+ accumulate_grad=False,
+ )
+ )
+ print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
+
+
+def perf_test_cuda_kernel(cuda_kernel_fn):
+ if torch.cuda.is_available():
+ # create CUDA event
+ start_event = torch.cuda.Event(enable_timing=True)
+ end_event = torch.cuda.Event(enable_timing=True)
+
+ # warmup
+ for _ in range(50):
+ cuda_kernel_fn()
+
+ start_event.record()
+ for _ in range(100):
+ cuda_kernel_fn()
+ end_event.record()
+ torch.cuda.synchronize()
+
+ elapsed_time_ms = start_event.elapsed_time(end_event)
+ return elapsed_time_ms / 100
+ else:
+ pytest.skip("CUDA is not available.")
+
+
+# TE tensor dtypes
+_te_dtypes: List[tex.DType] = [tex.DType.kFloat32, tex.DType.kFloat16]
+if is_bf16_compatible():
+ _te_dtypes.append(tex.DType.kBFloat16)
+
+
+@pytest.mark.parametrize("te_dtype", _te_dtypes)
+@pytest.mark.parametrize("num_tokens", [4096])
+@pytest.mark.parametrize("num_expert", [8, 16])
+@pytest.mark.parametrize("hidden_size", [4096])
+@pytest.mark.parametrize("topK", [1, 2, 5])
+@pytest.mark.parametrize("num_out_tokens", [None, 2039])
+def test_permutation(
+ te_dtype,
+ num_tokens,
+ num_expert,
+ hidden_size,
+ topK,
+ num_out_tokens,
+):
+ with_probs = True
+ BENCHMARK = False
+
+ _test_permutation(
+ te_dtype=te_dtype,
+ num_tokens=num_tokens,
+ num_expert=num_expert,
+ hidden_size=hidden_size,
+ topK=topK,
+ num_out_tokens=num_out_tokens,
+ with_probs=with_probs,
+ BENCHMARK=BENCHMARK,
+ )
+
+
+# Only run FP8 tests on devices with FP8 support.
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+
+
+@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
+@pytest.mark.parametrize("num_tokens", [2048])
+@pytest.mark.parametrize("num_expert", [8, 16])
+@pytest.mark.parametrize("hidden_size", [4096])
+@pytest.mark.parametrize("topK", [1, 2, 5])
+@pytest.mark.parametrize("num_out_tokens", [None, 2039])
+def test_permutation_fp8(
+ te_dtype,
+ num_tokens,
+ num_expert,
+ hidden_size,
+ topK,
+ num_out_tokens,
+):
+ with_probs = True
+ BENCHMARK = False
+
+ _test_permutation(
+ te_dtype=te_dtype,
+ num_tokens=num_tokens,
+ num_expert=num_expert,
+ hidden_size=hidden_size,
+ topK=topK,
+ num_out_tokens=num_out_tokens,
+ with_probs=with_probs,
+ BENCHMARK=BENCHMARK,
+ )
+
+
+@pytest.mark.parametrize("te_dtype", _te_dtypes)
+@pytest.mark.parametrize("num_tokens", [4096])
+@pytest.mark.parametrize("num_expert", [8, 16])
+@pytest.mark.parametrize("hidden_size", [4096])
+def test_permutation_topk1_no_probs(
+ te_dtype,
+ num_tokens,
+ num_expert,
+ hidden_size,
+):
+ topK = 1
+ num_out_tokens = None
+ with_probs = False
+ BENCHMARK = False
+
+ _test_permutation(
+ te_dtype=te_dtype,
+ num_tokens=num_tokens,
+ num_expert=num_expert,
+ hidden_size=hidden_size,
+ topK=topK,
+ num_out_tokens=num_out_tokens,
+ with_probs=with_probs,
+ BENCHMARK=BENCHMARK,
+ )
+
+
+def test_permutation_single_case():
+ print("GPU:", torch.cuda.get_device_name(0))
+
+ # te_dtype = tex.DType.kFloat32
+ # te_dtype = tex.DType.kFloat16
+ # te_dtype = tex.DType.kBFloat16
+ te_dtype = tex.DType.kFloat8E5M2
+ # te_dtype = tex.DType.kFloat8E4M3
+
+ num_tokens = 10
+ num_expert = 4
+ hidden_size = 16
+ topK = 2
+ num_out_tokens = num_tokens * topK - 1
+ with_probs = True
+ BENCHMARK = True
+
+ _test_permutation(
+ te_dtype=te_dtype,
+ num_tokens=num_tokens,
+ num_expert=num_expert,
+ hidden_size=hidden_size,
+ topK=topK,
+ num_out_tokens=num_out_tokens,
+ with_probs=with_probs,
+ BENCHMARK=BENCHMARK,
+ )
+
+
+if __name__ == "__main__":
+ test_permutation_single_case()
diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt
index 58bd4f828c..647d2c474d 100644
--- a/transformer_engine/common/CMakeLists.txt
+++ b/transformer_engine/common/CMakeLists.txt
@@ -48,6 +48,7 @@ set(transformer_engine_SOURCES)
list(APPEND transformer_engine_SOURCES
pycudnn.cpp
transformer_engine.cpp
+ common.cu
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
@@ -65,10 +66,12 @@ list(APPEND transformer_engine_SOURCES
layer_norm/ln_api.cpp
layer_norm/ln_bwd_semi_cuda_kernel.cu
layer_norm/ln_fwd_cuda_kernel.cu
+ permutation/permutation.cu
rmsnorm/rmsnorm_api.cpp
rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
rmsnorm/rmsnorm_fwd_cuda_kernel.cu
util/cast.cu
+ util/padding.cu
util/cuda_driver.cpp
util/cuda_runtime.cpp
util/rtc.cpp
diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py
index 46cfa9176a..4bcd1f8e27 100644
--- a/transformer_engine/common/__init__.py
+++ b/transformer_engine/common/__init__.py
@@ -47,20 +47,20 @@ def _get_sys_extension():
def _load_cudnn():
"""Load CUDNN shared library."""
-
+ # Attempt to locate cuDNN in Python dist-packages
lib_path = glob.glob(
os.path.join(
sysconfig.get_path("purelib"),
f"nvidia/cudnn/lib/libcudnn.{_get_sys_extension()}.*[0-9]",
)
)
-
if lib_path:
assert (
len(lib_path) == 1
), f"Found {len(lib_path)} libcudnn.{_get_sys_extension()}.x in nvidia-cudnn-cuXX."
return ctypes.CDLL(lib_path[0], mode=ctypes.RTLD_GLOBAL)
+ # Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set
cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
if cudnn_home:
libs = glob.glob(f"{cudnn_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True)
@@ -68,13 +68,14 @@ def _load_cudnn():
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
- cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
- if cuda_home:
- libs = glob.glob(f"{cuda_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True)
- libs.sort(reverse=True, key=os.path.basename)
- if libs:
- return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
+ # Attempt to locate cuDNN in CUDA_HOME, CUDA_PATH or /usr/local/cuda
+ cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
+ libs = glob.glob(f"{cuda_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True)
+ libs.sort(reverse=True, key=os.path.basename)
+ if libs:
+ return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
+ # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
@@ -91,14 +92,15 @@ def _load_library():
def _load_nvrtc():
"""Load NVRTC shared library."""
- cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
- if cuda_home:
- libs = glob.glob(f"{cuda_home}/**/libnvrtc.{_get_sys_extension()}*", recursive=True)
- libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs))
- libs.sort(reverse=True, key=os.path.basename)
- if libs:
- return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
-
+ # Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda
+ cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
+ libs = glob.glob(f"{cuda_home}/**/libnvrtc.{_get_sys_extension()}*", recursive=True)
+ libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs))
+ libs.sort(reverse=True, key=os.path.basename)
+ if libs:
+ return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
+
+ # Attempt to locate NVRTC via ldconfig
libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True)
libs = libs.decode("utf-8").split("\n")
sos = []
@@ -109,6 +111,8 @@ def _load_nvrtc():
sos.append(lib.split(">")[1].strip())
if sos:
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
+
+ # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h
index fc93705dff..6184e235bd 100644
--- a/transformer_engine/common/activation/activation_template.h
+++ b/transformer_engine/common/activation/activation_template.h
@@ -27,7 +27,8 @@ void act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
reinterpret_cast(input.data.dptr),
reinterpret_cast(output->data.dptr),
reinterpret_cast(output->scale.dptr),
- reinterpret_cast(output->amax.dptr), tot_elts, {},
+ reinterpret_cast(output->amax.dptr),
+ reinterpret_cast(output->scale_inv.dptr), tot_elts, {},
stream);); // NOLINT(*)
); // NOLINT(*)
}
@@ -50,7 +51,8 @@ void dact_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream
reinterpret_cast(input.data.dptr),
reinterpret_cast(output->data.dptr),
reinterpret_cast(output->scale.dptr),
- reinterpret_cast(output->amax.dptr), tot_elts, {},
+ reinterpret_cast(output->amax.dptr),
+ reinterpret_cast(output->scale_inv.dptr), tot_elts, {},
stream);); // NOLINT(*)
); // NOLINT(*)
}
@@ -74,7 +76,8 @@ void gated_act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
reinterpret_cast(input.data.dptr),
reinterpret_cast(output->data.dptr),
reinterpret_cast(output->scale.dptr),
- reinterpret_cast(output->amax.dptr), output->data.shape[0],
+ reinterpret_cast(output->amax.dptr),
+ reinterpret_cast(output->scale_inv.dptr), output->data.shape[0],
output->data.shape[1], {},
stream);); // NOLINT(*)
); // NOLINT(*)
diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu
new file mode 100644
index 0000000000..4e95fc24de
--- /dev/null
+++ b/transformer_engine/common/common.cu
@@ -0,0 +1,32 @@
+/*************************************************************************
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include
+
+#include "./common.h"
+#include "./utils.cuh"
+
+namespace transformer_engine {
+
+namespace {
+
+__global__ void __launch_bounds__(1)
+ update_tensor_scale_inv_kernel(const float* __restrict__ scale_ptr,
+ float* __restrict__ scale_inv_ptr) {
+ const float scale = scale_ptr == nullptr ? 1 : *scale_ptr;
+ reciprocal(scale_inv_ptr, scale);
+}
+
+} // namespace
+
+void update_tensor_scale_inv(Tensor* t, cudaStream_t stream) {
+ if (t->scale_inv.dptr != nullptr) {
+ update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>(
+ reinterpret_cast<const float *>(t->scale.dptr), reinterpret_cast<float *>(t->scale_inv.dptr));
+ }
+}
+
+} // namespace transformer_engine
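The kernel simply keeps scale_inv consistent with scale, which is the factor FP8 consumers multiply by to dequantize. A minimal host-side sketch of the same relationship (illustrative only, not the library API):

    import torch

    scale = torch.tensor([256.0], device="cuda")  # hypothetical FP8 quantization scale
    scale_inv = 1.0 / scale                       # what update_tensor_scale_inv writes
    x = torch.randn(8, device="cuda")
    x_roundtrip = (x * scale) * scale_inv         # quantize (FP8 cast omitted), then dequantize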
diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index 42b529f388..7e72e1b031 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -262,6 +262,13 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
bool is_fp8_dtype(const DType t);
+/*! \brief Update a tensor's FP8 scale-inverse
+ *
+ * The FP8 scale-inverse (dequantization scaling factor) is updated
+ * with the reciprocal of the FP8 scale (quantization scaling factor).
+ */
+void update_tensor_scale_inv(Tensor *t, cudaStream_t stream);
+
#define NVTE_API_CALL(api_name) \
transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name);
diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu
index c9b57752e2..593ec086d7 100644
--- a/transformer_engine/common/gemm/cublaslt_gemm.cu
+++ b/transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -255,7 +255,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
"Unable to find suitable cuBLAS GEMM algorithm");
NVTE_CHECK_CUBLAS(status);
- if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms");
+ if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms");
// D = alpha * (A * B) + beta * C
NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc,
@@ -269,6 +269,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
workspace, /* workspace */
workspaceSize, stream)); /* stream */
+ // Update FP8 scale-inv in output tensor
+ if (is_fp8_dtype(outputD->data.dtype)) {
+ update_tensor_scale_inv(outputD, stream);
+ }
+
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
diff --git a/transformer_engine/common/include/transformer_engine/padding.h b/transformer_engine/common/include/transformer_engine/padding.h
new file mode 100644
index 0000000000..a419b38234
--- /dev/null
+++ b/transformer_engine/common/include/transformer_engine/padding.h
@@ -0,0 +1,51 @@
+/*************************************************************************
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+/*! \file padding.h
+ * \brief Functions handling padding.
+ */
+
+#ifndef TRANSFORMER_ENGINE_PADDING_H_
+#define TRANSFORMER_ENGINE_PADDING_H_
+
+#include "transformer_engine.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/*! \brief Pad multiple tensors.
+ *
+ * NOTE: Only bottom padding is supported.
+ *
+ * For example, a 3x3 matrix is padded to a 4x3 matrix.
+ *
+ * source
+ * | 1 | 2 | 3 |
+ * | 4 | 5 | 6 |
+ * | 7 | 8 | 9 |
+ *
+ * destination
+ * | 1 | 2 | 3 |
+ * | 4 | 5 | 6 |
+ * | 7 | 8 | 9 |
+ * | 0 | 0 | 0 |
+ *
+ * \param[in] num_tensors Number of tensors.
+ * \param[in] input_list List of 2D input tensors.
+ * \param[in,out] output_list List of padded tensors. Column counts match the
+ * input tensors; row counts match padded_num_rows_list.
+ * \param[in] padded_num_rows_list List of padded num rows corresponding to input tensors.
+ * \param[in] stream CUDA stream used for the operation.
+ */
+void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list,
+ const int* padded_num_rows_list, cudaStream_t stream);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif // TRANSFORMER_ENGINE_PADDING_H_
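A reference implementation of the documented semantics in PyTorch (illustrative only; the real entry point is nvte_multi_padding and its framework wrappers):

    import torch

    def multi_padding_reference(input_list, padded_num_rows_list):
        out = []
        for x, padded_rows in zip(input_list, padded_num_rows_list):
            rows, cols = x.shape
            pad = torch.zeros(padded_rows - rows, cols, dtype=x.dtype, device=x.device)
            out.append(torch.cat([x, pad], dim=0))  # zero rows appended at the bottom
        return out

    src = torch.arange(1.0, 10.0).reshape(3, 3)   # the 3x3 example from the comment above
    dst = multi_padding_reference([src], [4])[0]  # 4x3: original rows plus one zero row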
diff --git a/transformer_engine/common/include/transformer_engine/permutation.h b/transformer_engine/common/include/transformer_engine/permutation.h
new file mode 100644
index 0000000000..c6263bf87e
--- /dev/null
+++ b/transformer_engine/common/include/transformer_engine/permutation.h
@@ -0,0 +1,21 @@
+/*************************************************************************
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#ifndef TRANSFORMER_ENGINE_PERMUTATION_H_
+#define TRANSFORMER_ENGINE_PERMUTATION_H_
+
+#include "transformer_engine.h"
+
+void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor sorted_row_id,
+ NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad,
+ const NVTETensor input_fwd, const int num_rows, const int topK,
+ const int num_cols, const int num_out_tokens, cudaStream_t stream = nullptr);
+
+void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map,
+ const NVTETensor prob, const int num_rows, const int topK, const int num_cols,
+ cudaStream_t stream = nullptr);
+
+#endif // TRANSFORMER_ENGINE_PERMUTATION_H_
diff --git a/transformer_engine/common/layer_norm/ln.h b/transformer_engine/common/layer_norm/ln.h
index 45839ed75b..13543a10aa 100644
--- a/transformer_engine/common/layer_norm/ln.h
+++ b/transformer_engine/common/layer_norm/ln.h
@@ -89,6 +89,9 @@ struct FwdParams : public ParamsBase {
// AMax output
void *amax;
+ // Inverse of scaling factor
+ void *scale_inv;
+
// Whether to compute scale and amax
bool fp8_out;
};
diff --git a/transformer_engine/common/layer_norm/ln_api.cpp b/transformer_engine/common/layer_norm/ln_api.cpp
index 115422e94e..8a40450e59 100644
--- a/transformer_engine/common/layer_norm/ln_api.cpp
+++ b/transformer_engine/common/layer_norm/ln_api.cpp
@@ -196,6 +196,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
params.epsilon = epsilon;
params.amax = z->amax.dptr;
params.scale = z->scale.dptr;
+ params.scale_inv = z->scale_inv.dptr;
params.fp8_out = fp8_out;
params.zero_centered_gamma = zero_centered_gamma;
diff --git a/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh b/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh
index 9fe4c16373..bd3741d1d1 100644
--- a/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh
+++ b/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh
@@ -132,10 +132,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
}
}
if (params.fp8_out) {
- amax = reduce_max(amax, warp);
- if (threadIdx.x == 0 && threadIdx.y == 0) {
- static_assert(std::is_same::value);
- atomicMaxFloat(reinterpret_cast(params.amax), amax);
+ // Reduce amax over block
+ if (params.amax != nullptr) {
+ amax = reduce_max(amax, warp);
+ if (threadIdx.x == 0) {
+ static_assert(std::is_same::value);
+ atomicMaxFloat(reinterpret_cast(params.amax), amax);
+ }
+ }
+
+ // Update scale-inverse
+ if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
+ reciprocal(reinterpret_cast(params.scale_inv), scale);
}
}
}
@@ -291,10 +299,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
// Finalize fp8 factors
if (params.fp8_out) {
- amax = reduce_max(amax, warp);
- if (threadIdx.x == 0) {
- static_assert(std::is_same::value);
- atomicMaxFloat(reinterpret_cast(params.amax), amax);
+ // Reduce amax over block
+ if (params.amax != nullptr) {
+ amax = reduce_max(amax, warp);
+ if (threadIdx.x == 0) {
+ static_assert(std::is_same::value);
+ atomicMaxFloat(reinterpret_cast(params.amax), amax);
+ }
+ }
+
+ // Update scale-inverse
+ if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
+ reciprocal(reinterpret_cast(params.scale_inv), scale);
}
}
}
diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu
new file mode 100644
index 0000000000..2b894fbfdc
--- /dev/null
+++ b/transformer_engine/common/permutation/permutation.cu
@@ -0,0 +1,369 @@
+/*************************************************************************
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include
+
+#include "../common.h"
+
+static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id_map,
+ const int num_rows, const int topK,
+ const int num_out_tokens) {
+ // Each thread handles one entry of the expanded (num_rows x topK) token space
+ // row_id_map[topK][num_rows]
+ const int bid = blockIdx.x;
+ const int tid = threadIdx.x;
+ const int idx = bid * blockDim.x + tid;
+
+ if (idx >= num_rows * topK) return;
+
+ int source_row = sorted_row_id[idx];
+ int source_token_id = source_row / topK;
+ int source_topK_id = source_row % topK;
+
+ if (idx >= num_out_tokens) {
+ // Set the indices of dropped tokens to -1
+ row_id_map[source_topK_id * num_rows + source_token_id] = -1;
+ } else {
+ // Create a row id map for subsequent unpermute operation
+ row_id_map[source_topK_id * num_rows + source_token_id] = idx;
+ }
+}
+
+template <typename T, typename TCompute, bool hasProb>
+__global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const int *row_id_map,
+ const float *prob, const int num_rows, const int topK,
+ const int num_cols) {
+ extern __shared__ int8_t s_mem[];
+ TCompute *s_prob = reinterpret_cast<TCompute *>(s_mem);
+
+ // Each block corresponds to one dest token
+ const int source_token = blockIdx.x;
+ const int tid = threadIdx.x;
+
+ if (hasProb) {
+ for (int i = tid; i < topK; i += blockDim.x * blockDim.y) {
+ // Load all the topK probs related to the source row into smem
+ s_prob[i] = TCompute(prob[source_token * topK + i]);
+ }
+ __syncthreads();
+ }
+
+ // Register buffers for vector type (float4) memory access
+ float4 frag_load_store;
+ T *frag_load_store_ptr = reinterpret_cast<T *>(&frag_load_store);
+
+ // Number of elements in frag_load_store
+ static constexpr int kElementsPerAccess = 16 / sizeof(T);
+
+ // Traverse along the hidden dimension
+ for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) {
+ TCompute frag_elem[kElementsPerAccess];
+ TCompute frag_sum[kElementsPerAccess];
+
+ int source_row = row_id_map[source_token];
+
+ // source_row == -1 represents a dropped token
+ if (source_row != -1) {
+ const T *source_row_ptr = input + source_row * num_cols;
+
+ frag_load_store = __ldlu(reinterpret_cast<const float4 *>(source_row_ptr + i));
+
+ for (int e = 0; e < kElementsPerAccess; e++) {
+ frag_sum[e] = TCompute(frag_load_store_ptr[e]);
+ }
+
+ if (hasProb) {
+ for (int e = 0; e < kElementsPerAccess; e++) {
+ frag_sum[e] = frag_sum[e] * s_prob[0];
+ }
+ }
+ } else {
+ for (int e = 0; e < kElementsPerAccess; e++) {
+ frag_sum[e] = TCompute(0.0f);
+ }
+ }
+
+ for (int k = 1; k < topK; k++) {
+ source_row = row_id_map[k * num_rows + source_token];
+
+ if (source_row == -1) continue;
+
+ const T *source_row_ptr = input + source_row * num_cols;
+
+ frag_load_store = __ldlu(reinterpret_cast<const float4 *>(source_row_ptr + i));
+
+ for (int e = 0; e < kElementsPerAccess; e++) {
+ frag_elem[e] = TCompute(frag_load_store_ptr[e]);
+ }
+
+ if (hasProb) {
+ for (int e = 0; e < kElementsPerAccess; e++) {
+ frag_elem[e] = frag_elem[e] * s_prob[k];
+ }
+ }
+
+ for (int e = 0; e < kElementsPerAccess; e++) {
+ frag_sum[e] = frag_sum[e] + frag_elem[e];
+ }
+ }
+
+ T *dest_row_ptr = unpermuted_output + source_token * num_cols;
+
+ for (int e = 0; e < kElementsPerAccess; e++) {
+ if constexpr ((std::is_same_v || std::is_same_v) &&
+ (!hasProb)) {
+ frag_sum[e] = frag_sum[e] / TCompute(topK);
+ }
+ frag_load_store_ptr[e] = T(frag_sum[e]);
+ }
+
+ *reinterpret_cast<float4 *>(dest_row_ptr + i) = frag_load_store;
+ }
+}
+
+template <typename T, typename TCompute, int topKTile, bool hasProb>
+__global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *act_grad,
+ const float *prob, float *prob_grad, const int *row_id_map,
+ const int num_rows, const int topK, const int num_cols) {
+ extern __shared__ int8_t s_mem[];
+ TCompute *s_prob = reinterpret_cast<TCompute *>(s_mem);
+
+ // Each block corresponds to one source token
+ const int source_token = blockIdx.x;
+ const int tid = threadIdx.x;
+
+ if (hasProb) {
+ for (int i = tid; i < topK; i += blockDim.x) {
+ // Load all the topK probs related to the source row into smem
+ s_prob[i] = TCompute(prob[source_token * topK + i]);
+ }
+ __syncthreads();
+ }
+
+ // Accumulators for the calculation of prob_grad
+ float accum[topKTile] = {0.0f};
+
+ // Register buffers for vector type (float4) memory access
+ float4 frag_load_store;
+ T *frag_load_store_ptr = reinterpret_cast<T *>(&frag_load_store);
+
+ // Number of elements in frag_load_store
+ static constexpr int kElementsPerAccess = 16 / sizeof(T);
+
+ // The starting address of each source row
+ const T *source_row_ptr = input_bwd + source_token * num_cols;
+
+ // Traverse along the hidden dimension
+ for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) {
+ TCompute frag_src[kElementsPerAccess];
+
+ frag_load_store = __ldlu(reinterpret_cast<const float4 *>(source_row_ptr + i));
+
+ for (int e = 0; e < kElementsPerAccess; e++) frag_src[e] = TCompute(frag_load_store_ptr[e]);
+
+ int index = source_token;
+
+ // Process each row in the corresponding topK rows
+ for (int k = 0; k < topKTile; k++) {
+ if (k == topK) break;
+
+ int dest_row = row_id_map[index];
+ index += num_rows;
+
+ if (dest_row != -1) {
+ if (hasProb) {
+ // Calculate act_grad in unpermute bwd
+ for (int e = 0; e < kElementsPerAccess; e++)
+ frag_load_store_ptr[e] = T(frag_src[e] * s_prob[k]);
+ } else {
+ // permute fwd
+ for (int e = 0; e < kElementsPerAccess; e++) frag_load_store_ptr[e] = T(frag_src[e]);
+ }
+
+ T *dest_row_ptr = act_grad + dest_row * num_cols;
+ *reinterpret_cast<float4 *>(dest_row_ptr + i) = frag_load_store;
+
+ if (hasProb) {
+ // Inner product calculation for prob_grad in unpermute bwd
+ const T *input_fwd_ptr = input_fwd + dest_row * num_cols;
+
+ frag_load_store = __ldlu(reinterpret_cast<const float4 *>(input_fwd_ptr + i));
+
+ TCompute frag_input_fwd[kElementsPerAccess];
+ for (int e = 0; e < kElementsPerAccess; e++)
+ frag_input_fwd[e] = TCompute(frag_load_store_ptr[e]);
+
+ for (int e = 0; e < kElementsPerAccess; e++) {
+ accum[k] += static_cast(frag_src[e] * frag_input_fwd[e]);
+ }
+ }
+ }
+ }
+ }
+
+ if (hasProb) {
+ for (int k = 0; k < topKTile; k++) {
+ if (k == topK) break;
+ // Warp-level reduction
+ for (int mask = 16; mask > 0; mask /= 2) {
+ accum[k] = accum[k] + __shfl_xor_sync(0xffffffff, accum[k], mask, 32);
+ }
+ }
+
+ if (tid == 0) {
+ for (int k = 0; k < topKTile; k++) {
+ if (k == topK) break;
+ prob_grad[source_token * topK + k] = accum[k];
+ }
+ }
+ }
+}
+
+template <typename T>
+void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, int *row_id_map,
+ const float *prob, float *prob_grad, const T *input_fwd,
+ const int num_rows, const int topK, const int num_cols,
+ const int num_out_tokens, cudaStream_t stream) {
+ using TCompute = typename std::conditional<(std::is_same::value ||
+ std::is_same::value),
+ half, T>::type;
+
+ static constexpr int kElementsPerAccess = 16 / sizeof(T);
+
+ if (input_fwd == nullptr) {
+ // moe_permute_fwd
+
+ int threads = 64;
+ int blocks = (num_rows * topK + threads - 1) / threads;
+
+ moe_permute_row_map<<>>(sorted_row_id, row_id_map, num_rows, topK,
+ num_out_tokens);
+
+ blocks = num_rows;
+ threads = std::min(num_cols / kElementsPerAccess, 1024);
+ moe_permute_kernel<<>>(
+ input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols);
+ } else {
+ // moe_unpermute_bwd
+
+ int threads = 32;
+ int blocks = num_rows;
+
+ if (prob == nullptr) {
+ // moe_unpermute_bwd without probs
+
+ moe_permute_kernel<<>>(
+ input, input_fwd, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols);
+ } else {
+ // moe_unpermute_bwd with probs
+
+ size_t smem_bytes = topK * sizeof(TCompute);
+
+ if (topK <= 8) {
+ moe_permute_kernel<<>>(
+ input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols);
+ } else if (topK <= 16) {
+ moe_permute_kernel<<>>(
+ input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols);
+ } else if (topK <= 32) {
+ moe_permute_kernel<<>>(
+ input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols);
+ } else if (topK <= 64) {
+ moe_permute_kernel<<>>(
+ input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols);
+ } else if (topK <= 128) {
+ moe_permute_kernel<<>>(
+ input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols);
+ } else {
+ NVTE_ERROR("topK cannot exceed 128.");
+ }
+ }
+ }
+}
+
+template <typename T>
+void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const float *prob,
+ const int num_rows, const int topK, const int num_cols,
+ cudaStream_t stream) {
+ using TCompute = typename std::conditional<(std::is_same::value ||
+ std::is_same::value),
+ half, T>::type;
+
+ static constexpr int kElementsPerAccess = 16 / sizeof(T);
+
+ int blocks = num_rows;
+ int threads = std::min(num_cols / kElementsPerAccess, 1024);
+ size_t smem_bytes = topK * sizeof(TCompute);
+
+ if (prob == nullptr) {
+ // moe_permute_bwd
+ // moe_unpermute_fwd without probs
+
+ moe_unpermute_kernel<<>>(
+ input, output, row_id_map, nullptr, num_rows, topK, num_cols);
+ } else {
+ // moe_unpermute_fwd with probs
+
+ moe_unpermute_kernel<<>>(
+ input, output, row_id_map, prob, num_rows, topK, num_cols);
+ }
+}
+
+void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor sorted_row_id,
+ NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad,
+ const NVTETensor input_fwd, const int num_rows, const int topK,
+ const int num_cols, const int num_out_tokens, cudaStream_t stream) {
+ NVTE_API_CALL(nvte_permute);
+
+ const transformer_engine::Tensor *input_cu =
+ reinterpret_cast<const transformer_engine::Tensor *>(input);
+ const transformer_engine::Tensor *output_cu =
+ reinterpret_cast<const transformer_engine::Tensor *>(output);
+ const transformer_engine::Tensor *sorted_row_id_cu =
+ reinterpret_cast<const transformer_engine::Tensor *>(sorted_row_id);
+ const transformer_engine::Tensor *row_id_map_cu =
+ reinterpret_cast<const transformer_engine::Tensor *>(row_id_map);
+ const transformer_engine::Tensor *prob_cu =
+ reinterpret_cast<const transformer_engine::Tensor *>(prob);
+ const transformer_engine::Tensor *prob_grad_cu =
+ reinterpret_cast<const transformer_engine::Tensor *>(prob_grad);
+ const transformer_engine::Tensor *input_fwd_cu =
+ reinterpret_cast<const transformer_engine::Tensor *>(input_fwd);
+
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
+ input_cu->data.dtype, T,
+ nvte_permute_launcher(reinterpret_cast<const T *>(input_cu->data.dptr),
+ reinterpret_cast<T *>(output_cu->data.dptr),
+ reinterpret_cast<const int *>(sorted_row_id_cu->data.dptr),
+ reinterpret_cast<int *>(row_id_map_cu->data.dptr),
+ reinterpret_cast<const float *>(prob_cu->data.dptr),
+ reinterpret_cast<float *>(prob_grad_cu->data.dptr),
+ reinterpret_cast<const T *>(input_fwd_cu->data.dptr), num_rows, topK,
+ num_cols, num_out_tokens, stream););
+}
+
+void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map,
+ const NVTETensor prob, const int num_rows, const int topK, const int num_cols,
+ cudaStream_t stream) {
+ NVTE_API_CALL(nvte_unpermute);
+
+ const transformer_engine::Tensor *input_cu =
+ reinterpret_cast<const transformer_engine::Tensor *>(input);
+ const transformer_engine::Tensor *output_cu =
+ reinterpret_cast<const transformer_engine::Tensor *>(output);
+ const transformer_engine::Tensor *row_id_map_cu =
+ reinterpret_cast<const transformer_engine::Tensor *>(row_id_map);
+ const transformer_engine::Tensor *prob_cu =
+ reinterpret_cast<const transformer_engine::Tensor *>(prob);
+
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
+ input_cu->data.dtype, T,
+ nvte_unpermute_launcher(reinterpret_cast<const T *>(input_cu->data.dptr),
+ reinterpret_cast<T *>(output_cu->data.dptr),
+ reinterpret_cast<int *>(row_id_map_cu->data.dptr),
+ reinterpret_cast<const float *>(prob_cu->data.dptr), num_rows, topK,
+ num_cols, stream););
+}
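For reference, the row_id_map layout produced by moe_permute_row_map can be written out in a few lines of Python (a sketch of the indexing only, not a replacement for the kernel):

    def build_row_id_map(sorted_row_id, num_rows, topK, num_out_tokens):
        # Laid out as row_id_map[topK][num_rows]; -1 marks dropped tokens.
        row_id_map = [0] * (topK * num_rows)
        for idx in range(num_rows * topK):
            source_row = sorted_row_id[idx]
            token, k = source_row // topK, source_row % topK
            row_id_map[k * num_rows + token] = idx if idx < num_out_tokens else -1
        return row_id_map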
diff --git a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp
index e9a6ff483d..9b143b2f85 100644
--- a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp
+++ b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp
@@ -159,6 +159,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
params.epsilon = epsilon;
params.amax = z->amax.dptr;
params.scale = z->scale.dptr;
+ params.scale_inv = z->scale_inv.dptr;
params.fp8_out = fp8_out;
params.zero_centered_gamma = zero_centered_gamma;
diff --git a/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh b/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh
index a1cfc2293c..c435ae3744 100644
--- a/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh
+++ b/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh
@@ -125,10 +125,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
}
}
if (params.fp8_out) {
- amax = reduce_max(amax, warp);
- if (threadIdx.x == 0 && threadIdx.y == 0) {
- static_assert(std::is_same::value);
- atomicMaxFloat(reinterpret_cast(params.amax), amax);
+ // Reduce amax over block
+ if (params.amax != nullptr) {
+ amax = reduce_max(amax, warp);
+ if (threadIdx.x == 0) {
+ static_assert(std::is_same::value);
+ atomicMaxFloat(reinterpret_cast(params.amax), amax);
+ }
+ }
+
+ // Update scale-inverse
+ if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
+ reciprocal(reinterpret_cast(params.scale_inv), scale);
}
}
}
@@ -267,10 +275,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
// Finalize fp8 factors
if (params.fp8_out) {
- amax = reduce_max(amax, warp);
- if (threadIdx.x == 0) {
- static_assert(std::is_same::value);
- atomicMaxFloat(reinterpret_cast(params.amax), amax);
+ // Reduce amax over block
+ if (params.amax != nullptr) {
+ amax = reduce_max(amax, warp);
+ if (threadIdx.x == 0) {
+ static_assert(std::is_same::value);
+ atomicMaxFloat(reinterpret_cast(params.amax), amax);
+ }
+ }
+
+ // Update scale-inverse
+ if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
+ reciprocal(reinterpret_cast(params.scale_inv), scale);
}
}
}
diff --git a/transformer_engine/common/transpose/cast_transpose.cu b/transformer_engine/common/transpose/cast_transpose.cu
index 6cbd4daade..dd45d0a668 100644
--- a/transformer_engine/common/transpose/cast_transpose.cu
+++ b/transformer_engine/common/transpose/cast_transpose.cu
@@ -101,14 +101,11 @@ struct KernelConfig {
};
template
-__global__ void __launch_bounds__(block_size)
- cast_transpose_general_kernel(const IType *__restrict__ const input,
- const CType *__restrict__ const noop,
- OType *__restrict__ const output_c,
- OType *__restrict__ const output_t,
- const CType *__restrict__ const scale_ptr,
- CType *__restrict__ const amax_ptr, const size_t row_length,
- const size_t num_rows) {
+__global__ void __launch_bounds__(block_size) cast_transpose_general_kernel(
+ const IType *__restrict__ const input, const CType *__restrict__ const noop,
+ OType *__restrict__ const output_c, OType *__restrict__ const output_t,
+ const CType *__restrict__ const scale_ptr, CType *__restrict__ const amax_ptr,
+ CType *__restrict__ const scale_inv_ptr, const size_t row_length, const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return;
// Vectorized load/store sizes
@@ -207,9 +204,15 @@ __global__ void __launch_bounds__(block_size)
if (amax_ptr != nullptr) {
amax = reduce_max(amax, tidy);
if (threadIdx.x == 0) {
+ static_assert(std::is_same::value);
atomicMaxFloat(amax_ptr, amax);
}
}
+
+ // Update scale-inverse
+ if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv_ptr != nullptr) {
+ reciprocal(scale_inv_ptr, scale);
+ }
}
} // namespace
@@ -255,6 +258,8 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output
"Cast and transposed outputs need to share amax tensor.");
NVTE_CHECK(cast_output.scale.dptr == transposed_output.scale.dptr,
"Cast and transposed outputs need to share scale tensor.");
+ NVTE_CHECK(cast_output.scale_inv.dptr == transposed_output.scale_inv.dptr,
+ "Cast and transposed outputs need to share scale-inverse tensor.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, InputType,
@@ -324,7 +329,9 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output
static_cast(cast_output.data.dptr),
static_cast(transposed_output.data.dptr),
static_cast(cast_output.scale.dptr),
- static_cast(cast_output.amax.dptr), row_length, num_rows);
+ static_cast(cast_output.amax.dptr),
+ static_cast(cast_output.scale_inv.dptr), row_length,
+ num_rows);
} else { // Statically-compiled general kernel
constexpr size_t load_size = 4;
constexpr size_t store_size = 4;
@@ -339,7 +346,8 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output
static_cast(cast_output.data.dptr),
static_cast(transposed_output.data.dptr),
static_cast(cast_output.scale.dptr),
- static_cast(cast_output.amax.dptr), row_length, num_rows);
+ static_cast(cast_output.amax.dptr),
+ static_cast(cast_output.scale_inv.dptr), row_length, num_rows);
}); // NOLINT(*)
); // NOLINT(*)
}
diff --git a/transformer_engine/common/transpose/cast_transpose_fusion.cu b/transformer_engine/common/transpose/cast_transpose_fusion.cu
index d96757990f..a8361d57ea 100644
--- a/transformer_engine/common/transpose/cast_transpose_fusion.cu
+++ b/transformer_engine/common/transpose/cast_transpose_fusion.cu
@@ -433,15 +433,19 @@ __global__ void __launch_bounds__(cast_transpose_num_threads)
}
}
- /* warp tile amax reduce*/
- amax = reduce_max(amax, warp_id);
-
- if (threadIdx.x == 0) {
- static_assert(std::is_same::value);
- if (param.amax != nullptr) {
+ // Reduce amax over block
+ if (param.amax != nullptr) {
+ amax = reduce_max(amax, warp_id);
+ if (threadIdx.x == 0) {
+ static_assert(std::is_same::value);
atomicMaxFloat(param.amax, amax);
}
}
+
+ // Update scale-inverse
+ if (blockIdx.x == 0 && threadIdx.x == 0 && param.scale_inv != nullptr) {
+ reciprocal(param.scale_inv, scale);
+ }
}
static const char *ActTypeToString[] = {
@@ -870,17 +874,18 @@ __global__ void __launch_bounds__(cast_transpose_num_threads)
__syncthreads();
}
- /* warp tile amax reduce*/
- max = reduce_max(max, warp_id);
-
- if (threadIdx.x == 0) {
- static_assert(std::is_same::value);
- if (amax != nullptr) {
+ // Reduce amax over block
+ if (amax != nullptr) {
+ max = reduce_max(max, warp_id);
+ if (threadIdx.x == 0) {
+ static_assert(std::is_same::value);
atomicMaxFloat(amax, max);
}
- if (scale_inv != nullptr) {
- reciprocal(scale_inv, scale);
- }
+ }
+
+ // Update scale-inverse
+ if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
+ reciprocal(scale_inv, scale);
}
}
@@ -1079,17 +1084,18 @@ __global__ void __launch_bounds__(cast_transpose_num_threads)
__syncthreads();
}
- /* warp tile amax reduce*/
- max = reduce_max(max, warp_id);
-
- if (threadIdx.x == 0) {
- static_assert(std::is_same::value);
- if (amax != nullptr) {
+ // Reduce amax over block
+ if (amax != nullptr) {
+ max = reduce_max(max, warp_id);
+ if (threadIdx.x == 0) {
+ static_assert(std::is_same::value);
atomicMaxFloat(amax, max);
}
- if (scale_inv != nullptr) {
- reciprocal(scale_inv, scale);
- }
+ }
+
+ // Update scale-inverse
+ if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
+ reciprocal(scale_inv, scale);
}
}
diff --git a/transformer_engine/common/transpose/multi_cast_transpose.cu b/transformer_engine/common/transpose/multi_cast_transpose.cu
index 8e6e90a7bf..4026016519 100644
--- a/transformer_engine/common/transpose/multi_cast_transpose.cu
+++ b/transformer_engine/common/transpose/multi_cast_transpose.cu
@@ -36,6 +36,8 @@ struct MultiCastTransposeArgs {
void* scale_list[kMaxTensorsPerKernel];
// (output) AMAX's of input tensors
void* amax_list[kMaxTensorsPerKernel];
+ // (output) Inverse of scaling factor for output tensors
+ void* scale_inv_list[kMaxTensorsPerKernel];
// Input matrix heights
int num_rows_list[kMaxTensorsPerKernel];
// Input matrix widths
@@ -82,7 +84,8 @@ __global__ void __launch_bounds__(threads_per_block)
OType* output_t = reinterpret_cast(args.output_t_list[tensor_id]);
const CType* scale_ptr = reinterpret_cast(args.scale_list[tensor_id]);
const CType scale = scale_ptr == nullptr ? 1 : *scale_ptr;
- CType* amax = reinterpret_cast(args.amax_list[tensor_id]);
+ CType* amax_ptr = reinterpret_cast(args.amax_list[tensor_id]);
+ CType* scale_inv_ptr = reinterpret_cast(args.scale_inv_list[tensor_id]);
const int num_rows = args.num_rows_list[tensor_id];
const int row_length = args.row_length_list[tensor_id];
@@ -183,7 +186,10 @@ __global__ void __launch_bounds__(threads_per_block)
local_amax = reduce_max(local_amax, tidy);
if (tid == 0) {
static_assert(std::is_same::value);
- if (amax != nullptr) atomicMaxFloat(amax, local_amax);
+ if (amax_ptr != nullptr) atomicMaxFloat(amax_ptr, local_amax);
+ }
+ if (tile_id == 0 && tid == 0 && scale_inv_ptr != nullptr) {
+ reciprocal(scale_inv_ptr, scale);
}
}
@@ -285,6 +291,7 @@ void multi_cast_transpose(const std::vector input_list,
kernel_args.output_t_list[pos] = transposed_output_list[tensor_id]->data.dptr;
kernel_args.scale_list[pos] = cast_output_list[tensor_id]->scale.dptr;
kernel_args.amax_list[pos] = cast_output_list[tensor_id]->amax.dptr;
+ kernel_args.scale_inv_list[pos] = cast_output_list[tensor_id]->scale_inv.dptr;
kernel_args.num_rows_list[pos] = num_rows;
kernel_args.row_length_list[pos] = row_length;
kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles;
diff --git a/transformer_engine/common/transpose/rtc/cast_transpose.cu b/transformer_engine/common/transpose/rtc/cast_transpose.cu
index 6ea8326147..07244a42e9 100644
--- a/transformer_engine/common/transpose/rtc/cast_transpose.cu
+++ b/transformer_engine/common/transpose/rtc/cast_transpose.cu
@@ -25,7 +25,7 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel(
const IType* __restrict__ const input, const CType* __restrict__ const noop,
OType* __restrict__ const output_c, OType* __restrict__ const output_t,
const CType* __restrict__ const scale_ptr, CType* __restrict__ const amax_ptr,
- const size_t row_length, const size_t num_rows) {
+ CType* __restrict__ const scale_inv_ptr, const size_t row_length, const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return;
// Vectorized load/store sizes
@@ -121,4 +121,9 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel(
atomicMaxFloat(amax_ptr, amax);
}
}
+
+ // Update scale-inverse
+ if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv_ptr != nullptr) {
+ reciprocal(scale_inv_ptr, scale);
+ }
}
diff --git a/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu b/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu
index c005be98ef..4ba1cb4c69 100644
--- a/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu
+++ b/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu
@@ -229,12 +229,16 @@ __global__ void __launch_bounds__(BLOCK_SIZE)
}
}
- // warp tile amax reduce
- const CType max_block = reduce_max(amax, warp_id);
-
- if (threadIdx.x == 0) {
- if (param.amax != nullptr) {
+ // Reduce amax over block
+ if (param.amax != nullptr) {
+ const CType max_block = reduce_max(amax, warp_id);
+ if (threadIdx.x == 0) {
atomicMaxFloat(param.amax, max_block);
}
}
+
+ // Update scale-inverse
+ if (blockIdx.x == 0 && threadIdx.x == 0 && param.scale_inv != nullptr) {
+ reciprocal(param.scale_inv, scale);
+ }
}
diff --git a/transformer_engine/common/util/cast.cu b/transformer_engine/common/util/cast.cu
index 8dd2b98ebf..dd03afd21b 100644
--- a/transformer_engine/common/util/cast.cu
+++ b/transformer_engine/common/util/cast.cu
@@ -46,7 +46,8 @@ void fp8_quantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
reinterpret_cast(input.data.dptr),
reinterpret_cast(output->data.dptr),
reinterpret_cast(output->scale.dptr),
- reinterpret_cast(output->amax.dptr), N, {},
+ reinterpret_cast(output->amax.dptr),
+ reinterpret_cast(output->scale_inv.dptr), N, {},
stream);); // NOLINT(*)
); // NOLINT(*)
}
@@ -68,7 +69,7 @@ void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
p.scale_inv = reinterpret_cast(input.scale_inv.dptr);
VectorizedUnaryKernelLauncher(
reinterpret_cast(input.data.dptr),
- reinterpret_cast(output->data.dptr), nullptr, nullptr, N, p,
+ reinterpret_cast(output->data.dptr), nullptr, nullptr, nullptr, N, p,
stream);); // NOLINT(*)
); // NOLINT(*)
}
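With scale_inv now written by the quantize path, the quantize/dequantize pair behaves roughly like the sketch below for torch or numpy arrays (per-tensor scaling, FP8 cast omitted; illustrative only):

    def quantize_ref(x, scale):
        amax = x.abs().max()           # recorded so the framework can update scale later
        x_q = x * scale                # the real kernel also casts to the FP8 format
        return x_q, amax, 1.0 / scale  # data, amax, scale_inv

    def dequantize_ref(x_q, scale_inv):
        return x_q * scale_inv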
diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu
new file mode 100644
index 0000000000..017d2e6a56
--- /dev/null
+++ b/transformer_engine/common/util/padding.cu
@@ -0,0 +1,219 @@
+/*************************************************************************
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include
+#include
+
+#include
+#include
+#include
+
+#include "../common.h"
+#include "../utils.cuh"
+
+namespace transformer_engine {
+
+namespace {
+
+// Parameters to tune
+constexpr int n_warps_per_tile = 4;
+constexpr int threads_per_block = THREADS_PER_WARP * n_warps_per_tile;
+constexpr int desired_load_store_size = 8;
+constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB
+
+struct MultiPaddingArgs {
+ // (input) Data buffers for input tensors
+ void* input_list[kMaxTensorsPerKernel];
+ // (output) Data buffers for padded output tensors
+ void* output_list[kMaxTensorsPerKernel];
+ // Input matrix heights
+ int num_rows_list[kMaxTensorsPerKernel];
+ // Input matrix heights (padded)
+ int padded_num_rows_list[kMaxTensorsPerKernel];
+ // Input matrix widths
+ int row_length_list[kMaxTensorsPerKernel];
+ // Prefix sum (with leading zero) of CUDA blocks needed for each tensor
+ int block_range[kMaxTensorsPerKernel + 1];
+ // Number of tensors being processed by kernel
+ int num_tensors;
+};
+
+template
+__global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiPaddingArgs args) {
+ using Vec = Vec;
+
+ // Thread indices
+ // Note: Block is interpreted as a warp_size x num_warps grid
+ constexpr int bdimx = THREADS_PER_WARP;
+ constexpr int bdimy = n_warps_per_tile;
+ const int tid = threadIdx.x;
+ const int tidx = tid % bdimx;
+ const int tidy = tid / bdimx;
+ const int bid = blockIdx.x;
+
+ // Input tensors are divided into tiles
+ // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles
+ constexpr int tile_dim_m = THREADS_PER_WARP * nvec;
+ constexpr int tile_dim_n = THREADS_PER_WARP * nvec;
+
+ // Number of nvec x nvec subtiles for each thread to
+ // load/store
+ constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
+
+ // Find tensor corresponding to block
+ int tensor_id = 0;
+ while (args.block_range[tensor_id + 1] <= bid) {
+ ++tensor_id;
+ }
+ const Type* input = reinterpret_cast<const Type*>(args.input_list[tensor_id]);
+ Type* output = reinterpret_cast<Type*>(args.output_list[tensor_id]);
+ const int num_rows = args.num_rows_list[tensor_id];
+ const int padded_num_rows = args.padded_num_rows_list[tensor_id];
+ const int row_length = args.row_length_list[tensor_id];
+
+ // Find position of tile within tensor
+ const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
+ const int tile_id = bid - args.block_range[tensor_id];
+ const int tile_id_m = tile_id / num_tiles_n;
+ const int tile_id_n = tile_id % num_tiles_n;
+ const int tile_row = tile_id_m * tile_dim_m;
+ const int tile_col = tile_id_n * tile_dim_n;
+
+ // Load input and write it back through registers
+ // Note: Each thread handles n_iterations subtiles, zero-filling
+ // rows that fall in the padded region.
+ Type local_zero = static_cast<Type>(0.f);
+#pragma unroll
+ for (int iter = 0; iter < n_iterations; ++iter) {
+ const int i1 = tidy + iter * bdimy;
+ const int j1 = tidx;
+#pragma unroll
+ for (int i2 = 0; i2 < nvec; ++i2) {
+ const int row = tile_row + i1 * nvec + i2;
+ const int col = tile_col + j1 * nvec;
+ Vec local_input;
+ Vec local_output;
+ local_input.clear();
+ if (row < num_rows) {
+ for (int j2 = 0; j2 < nvec; ++j2) {
+ if (col + j2 < row_length) {
+ local_input.data.elt[j2] = input[row * row_length + col + j2];
+ }
+ }
+ }
+#pragma unroll
+ for (int j2 = 0; j2 < nvec; ++j2) {
+ local_output.data.elt[j2] = local_input.data.elt[j2];
+ }
+ if (row < num_rows) {
+ for (int j2 = 0; j2 < nvec; ++j2) {
+ if (col + j2 < row_length) {
+ output[row * row_length + col + j2] = local_output.data.elt[j2];
+ }
+ }
+ } else if (row < padded_num_rows) {
+ // padding
+ for (int j2 = 0; j2 < nvec; ++j2) {
+ if (col + j2 < row_length) {
+ output[row * row_length + col + j2] = local_zero;
+ }
+ }
+ }
+ }
+ }
+}
+
+} // namespace
+
+void multi_padding(const std::vector input_list, std::vector output_list,
+ const std::vector padded_num_rows_list, cudaStream_t stream) {
+ // Check that number of tensors is valid
+ NVTE_CHECK(output_list.size() == input_list.size(),
+ "Number of input and output tensors must match");
+ if (input_list.empty()) {
+ return;
+ }
+
+ // Check that tensor properties are valid
+ DType type = input_list[0]->data.dtype;
+ for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
+ const auto& input = *input_list[tensor_id];
+ const auto& output = *output_list[tensor_id];
+ CheckInputTensor(input, "multi_padding_input_" + std::to_string(tensor_id));
+ CheckInputTensor(output, "multi_padding_output_" + std::to_string(tensor_id));
+
+ NVTE_CHECK(input.data.dtype == type, "Input tensor types do not match.");
+ NVTE_CHECK(output.data.dtype == type, "Output tensor types do not match.");
+
+ NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions.");
+ NVTE_CHECK(output.data.shape[0] == padded_num_rows_list[tensor_id],
+ "output tensor shape does not match padded input shape.");
+ }
+
+ // Input matrices are divided into tiles
+ // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles
+ const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size / typeToSize(type);
+ const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size / typeToSize(type);
+
+ // Add tensors to kernel argument struct
+ MultiPaddingArgs kernel_args;
+ kernel_args.num_tensors = 0;
+ kernel_args.block_range[0] = 0;
+ for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
+ // Launch kernel if argument struct is full
+ if (kernel_args.num_tensors == kMaxTensorsPerKernel) {
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
+ type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type);
+ const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
+ multi_padding_kernel
+ <<>>(kernel_args);); // NOLINT(*)
+ kernel_args.num_tensors = 0;
+ }
+
+ // Calculate number of thread blocks needed for tensor
+ const int num_rows = input_list[tensor_id]->data.shape[0];
+ const int padded_num_rows = padded_num_rows_list[tensor_id];
+ const int row_length = input_list[tensor_id]->data.shape[1];
+ const int num_tiles_m = (padded_num_rows + tile_dim_m - 1) / tile_dim_m;
+ const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
+ const int num_tiles = num_tiles_m * num_tiles_n;
+
+ // Add tensor to kernel argument struct
+ const int pos = kernel_args.num_tensors;
+ kernel_args.input_list[pos] = const_cast(input_list[tensor_id]->data.dptr);
+ kernel_args.output_list[pos] = output_list[tensor_id]->data.dptr;
+ kernel_args.num_rows_list[pos] = num_rows;
+ kernel_args.padded_num_rows_list[pos] = padded_num_rows;
+ kernel_args.row_length_list[pos] = row_length;
+ kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles;
+ kernel_args.num_tensors++;
+ }
+
+ // Launch kernel
+ if (kernel_args.num_tensors > 0) {
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
+ type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type);
+ const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
+ multi_padding_kernel
+ <<>>(kernel_args);); // NOLINT(*)
+ }
+}
+
+} // namespace transformer_engine
+
+void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list,
+ const int* padded_num_rows_list, cudaStream_t stream) {
+ NVTE_API_CALL(nvte_multi_padding);
+ using namespace transformer_engine;
+ std::vector input_list_, output_list_;
+ std::vector padded_num_rows_list_;
+ for (size_t i = 0; i < num_tensors; ++i) {
+ input_list_.push_back(reinterpret_cast(const_cast(input_list[i])));
+ output_list_.push_back(reinterpret_cast(output_list[i]));
+ padded_num_rows_list_.push_back(padded_num_rows_list[i]);
+ }
+ multi_padding(input_list_, output_list_, padded_num_rows_list_, stream);
+}
diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h
index 63ad1857cf..8653bf45a4 100644
--- a/transformer_engine/common/util/vectorized_pointwise.h
+++ b/transformer_engine/common/util/vectorized_pointwise.h
@@ -168,12 +168,12 @@ template
__launch_bounds__(unary_kernel_threads) __global__
void unary_kernel(const InputType *input, OutputType *output, const ComputeType *scale,
- ComputeType *amax, Param p, const size_t N,
+ ComputeType *amax, ComputeType *scale_inv, Param p, const size_t N,
const size_t num_aligned_elements) {
VectorizedLoader loader(input, N);
VectorizedStorer storer(output, N);
ComputeType max = 0;
- ComputeType s = 0;
+ ComputeType s = 1;
if constexpr (is_fp8::value) {
if (scale != nullptr) s = *scale;
}
@@ -199,12 +199,18 @@ __launch_bounds__(unary_kernel_threads) __global__
storer.store(tid, N);
}
if constexpr (is_fp8::value) {
- /* warp tile amax reduce*/
- max = reduce_max(max, warp_id);
+ // Reduce amax over block
+ if (amax != nullptr) {
+ max = reduce_max(max, warp_id);
+ if (threadIdx.x == 0) {
+ static_assert(std::is_same::value);
+ atomicMaxFloat(amax, max);
+ }
+ }
- if (threadIdx.x == 0 && amax != nullptr) {
- static_assert(std::is_same::value);
- atomicMaxFloat(amax, max);
+ // Update scale-inverse
+ if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
+ reciprocal(scale_inv, s);
}
}
}
@@ -214,13 +220,13 @@ template
__launch_bounds__(unary_kernel_threads) __global__
void unary_grad_kernel(const InputTypeGrad *grad, const InputType *input, OutputType *output,
- const ComputeType *scale, ComputeType *amax, Param p, const size_t N,
- const size_t num_aligned_elements) {
+ const ComputeType *scale, ComputeType *amax, ComputeType *scale_inv,
+ Param p, const size_t N, const size_t num_aligned_elements) {
VectorizedLoader loader(input, N);
VectorizedLoader grad_loader(grad, N);
VectorizedStorer storer(output, N);
ComputeType max = 0;
- ComputeType s = 0;
+ ComputeType s = 1;
if constexpr (is_fp8::value) {
if (scale != nullptr) s = *scale;
}
@@ -248,12 +254,18 @@ __launch_bounds__(unary_kernel_threads) __global__
storer.store(tid, N);
}
if constexpr (is_fp8::value) {
- /* warp tile amax reduce*/
- max = reduce_max(max, warp_id);
+ // Reduce amax over block
+ if (amax != nullptr) {
+ max = reduce_max(max, warp_id);
+ if (threadIdx.x == 0) {
+ static_assert(std::is_same::value);
+ atomicMaxFloat(amax, max);
+ }
+ }
- if (threadIdx.x == 0 && amax != nullptr) {
- static_assert(std::is_same