Release v1.11
ptrendx committed Oct 8, 2024
2 parents 08a85d3 + 4df8488 commit c27ee60
Showing 118 changed files with 8,107 additions and 3,165 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/deploy_nightly_docs.yml
@@ -16,7 +16,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Download artifact
-        uses: actions/download-artifact@v3
+        uses: actions/download-artifact@v4.1.7
         with:
           name: "te_docs"
           path: "html"
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -31,7 +31,7 @@ jobs:
       run: |
         sudo apt-get update
         sudo apt-get install pip -y
-        pip install torch
+        pip install torch numpy
         export PYTHON_ONLY=1
         export TE_PATH=.
         bash ./qa/L0_pytorch_lint/test.sh
3 changes: 3 additions & 0 deletions .github/workflows/trigger-ci.yml
@@ -33,6 +33,9 @@ jobs:
         || github.actor == 'Oleg-Goncharov'
         || github.actor == 'phu0ngng'
         || github.actor == 'xrennvidia'
+        || github.actor == 'yaox12'
+        || github.actor == 'huanghua1994'
+        || github.actor == 'mgoldfarb-nvidia'
       )
     steps:
       - name: Check if comment is issued by authorized person
10 changes: 9 additions & 1 deletion README.rst
@@ -174,7 +174,15 @@ To install the latest stable version of Transformer Engine,
 
     pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
 
-This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch).
+This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch,paddle).
+
+Alternatively, the package can be directly installed from `Transformer Engine's PyPI <https://pypi.org/project/transformer-engine/>`_, e.g.
+
+.. code-block:: bash
+
+    pip install transformer_engine[pytorch]
+
+To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch,paddle]). Transformer Engine ships wheels for the core library as well as the PaddlePaddle extensions. Source distributions are shipped for the JAX and PyTorch extensions.
 
 From source
 ^^^^^^^^^^^
2 changes: 1 addition & 1 deletion benchmarks/attention/benchmark_attention.py
@@ -156,7 +156,7 @@ def parse_results(per_cudnn, per_flash, model):
         df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6
 
     if per_flash > 0:
-        t_flash_all = df[df["Name"].str.contains("void flash")]["Duration (ns)"].to_numpy()
+        t_flash_all = df[df["Name"].str.contains("flash")]["Duration (ns)"].to_numpy()
         t_flash_all = t_flash_all.reshape(-1, per_flash)
         t_flash_avg = np.average(t_flash_all, axis=0)
         df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6
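
The loosened pattern also catches flash-attention kernels whose names do not begin with "void flash" (kernel naming varies across versions). A tiny self-contained pandas illustration of the filter, using synthetic rows rather than real benchmark output:

    import pandas as pd

    # Synthetic profiler rows; flash-attention kernel names vary by version.
    df = pd.DataFrame({
        "Name": ["void flash_fwd_kernel<...>", "flash::fwd_splitkv", "gemm_kernel"],
        "Duration (ns)": [1000, 500, 2000],
    })
    flash = df[df["Name"].str.contains("flash")]["Duration (ns)"].to_numpy()
    print(flash)  # [1000  500] -- matching on "void flash" would miss the second row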
2 changes: 1 addition & 1 deletion build_tools/VERSION.txt
@@ -1 +1 @@
-1.10.0
+1.11.0
8 changes: 6 additions & 2 deletions build_tools/build_ext.py
@@ -106,8 +106,12 @@ def run(self) -> None:
         if isinstance(ext, CMakeExtension):
             print(f"Building CMake extension {ext.name}")
             # Set up incremental builds for CMake extensions
-            setup_dir = Path(__file__).resolve().parent.parent
-            build_dir = setup_dir / "build" / "cmake"
+            build_dir = os.getenv("NVTE_CMAKE_BUILD_DIR")
+            if build_dir:
+                build_dir = Path(build_dir).resolve()
+            else:
+                root_dir = Path(__file__).resolve().parent.parent
+                build_dir = root_dir / "build" / "cmake"
 
             # Ensure the directory exists
             build_dir.mkdir(parents=True, exist_ok=True)
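
A minimal standalone sketch of the directory-resolution logic added above, assuming the same NVTE_CMAKE_BUILD_DIR convention (the paths here are hypothetical):

    import os
    from pathlib import Path

    def resolve_build_dir(root_dir: Path) -> Path:
        # Env override wins; otherwise fall back to <root>/build/cmake.
        env_dir = os.getenv("NVTE_CMAKE_BUILD_DIR")
        build_dir = Path(env_dir).resolve() if env_dir else root_dir / "build" / "cmake"
        build_dir.mkdir(parents=True, exist_ok=True)  # ensure the directory exists
        return build_dir

    os.environ["NVTE_CMAKE_BUILD_DIR"] = "/tmp/te-cmake"  # hypothetical location
    print(resolve_build_dir(Path.cwd()))                  # -> /tmp/te-cmake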
21 changes: 13 additions & 8 deletions build_tools/pytorch.py
@@ -10,8 +10,9 @@
 
 from .utils import (
     all_files_in_dir,
-    cuda_version,
+    cuda_archs,
     cuda_path,
+    cuda_version,
 )


@@ -48,8 +49,6 @@ def setup_pytorch_extension(
     ]
     nvcc_flags = [
         "-O3",
-        "-gencode",
-        "arch=compute_70,code=sm_70",
         "-U__CUDA_NO_HALF_OPERATORS__",
         "-U__CUDA_NO_HALF_CONVERSIONS__",
         "-U__CUDA_NO_BFLOAT16_OPERATORS__",
@@ -61,6 +60,11 @@
         "--use_fast_math",
     ]
 
+    cuda_architectures = cuda_archs()
+
+    if "70" in cuda_architectures:
+        nvcc_flags.extend(["-gencode", "arch=compute_70,code=sm_70"])
+
     # Version-dependent CUDA options
     try:
         version = cuda_version()
@@ -73,17 +77,18 @@
             (
                 "--threads",
                 os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"),
-                "-gencode",
-                "arch=compute_80,code=sm_80",
-                "-gencode",
-                "arch=compute_90,code=sm_90",
             )
         )
 
+    if "80" in cuda_architectures:
+        nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
+    if "90" in cuda_architectures:
+        nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"])
+
     # Libraries
     library_dirs = []
     libraries = []
-    if os.getenv("NVTE_UB_WITH_MPI"):
+    if bool(int(os.getenv("NVTE_UB_WITH_MPI", 0))):
         assert (
             os.getenv("MPI_HOME") is not None
         ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
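
A minimal sketch (not the project's exact code, which checks substrings of the arch list as shown above) of how NVTE_CUDA_ARCHS-style entries expand into per-architecture NVCC flags:

    def gencode_flags(archs: str) -> list:
        # "80;90" -> ["-gencode", "arch=compute_80,code=sm_80", ...]
        flags = []
        for arch in archs.split(";"):
            flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])
        return flags

    print(gencode_flags("80;90"))
    # ['-gencode', 'arch=compute_80,code=sm_80', '-gencode', 'arch=compute_90,code=sm_90']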
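
The stricter NVTE_UB_WITH_MPI check fixes a string-truthiness pitfall: os.getenv returns a string, so NVTE_UB_WITH_MPI=0 previously still enabled the MPI path. A standalone illustration:

    import os

    os.environ["NVTE_UB_WITH_MPI"] = "0"
    old_check = bool(os.getenv("NVTE_UB_WITH_MPI"))            # True -- any non-empty string
    new_check = bool(int(os.getenv("NVTE_UB_WITH_MPI", "0")))  # False -- value actually parsed
    print(old_check, new_check)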
7 changes: 6 additions & 1 deletion build_tools/utils.py
@@ -6,12 +6,12 @@
 
 import functools
 import glob
+import importlib
 import os
 import re
 import shutil
 import subprocess
 import sys
-import importlib
 from pathlib import Path
 from subprocess import CalledProcessError
 from typing import List, Optional, Tuple, Union
@@ -188,6 +188,11 @@ def cuda_path() -> Tuple[str, str]:
     return cuda_home, nvcc_bin
 
 
+@functools.lru_cache(maxsize=None)
+def cuda_archs() -> str:
+    return os.getenv("NVTE_CUDA_ARCHS", "70;80;89;90")
+
+
 def cuda_version() -> Tuple[int, ...]:
     """CUDA Toolkit version as a (major, minor) tuple."""
     # Query NVCC for version info
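
Because cuda_archs is wrapped in functools.lru_cache, the environment variable is read only once per process; a small standalone demonstration of that caching behavior:

    import functools
    import os

    @functools.lru_cache(maxsize=None)
    def cuda_archs() -> str:
        # Same pattern as above: read on first call, served from cache afterwards.
        return os.getenv("NVTE_CUDA_ARCHS", "70;80;89;90")

    os.environ["NVTE_CUDA_ARCHS"] = "80;90"
    print(cuda_archs())  # 80;90
    os.environ["NVTE_CUDA_ARCHS"] = "90"
    print(cuda_archs())  # still 80;90 -- set the variable before the first call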
15 changes: 12 additions & 3 deletions docs/_templates/layout.html
@@ -1,4 +1,11 @@
 {% extends "!layout.html" %}
 
+{% block extrahead %}
+
+<script src="https://assets.adobedtm.com/5d4962a43b79/c1061d2c5e7b/launch-191c2462b890.min.js"></script>
+
+{% endblock %}
+
 {% block sidebartitle %} {{ super() }}
 
 <style>
@@ -83,8 +90,10 @@
 }
 </style>
 
-{%- if nvidia_analytics_id %}
-<script type="text/javascript">_satellite.pageBottom();</script>
-{%- endif %}
 {% endblock %}
 
+{% block footer %}
+
+<script type="text/javascript">if (typeof _satellite !== "undefined"){ _satellite.pageBottom();}</script>
+
+{% endblock %}
7 changes: 7 additions & 0 deletions docs/api/pytorch.rst
@@ -9,6 +9,9 @@ pyTorch
 .. autoapiclass:: transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs)
   :members: forward, set_tensor_parallel_group
 
+.. autoapiclass:: transformer_engine.pytorch.GroupedLinear(in_features, out_features, bias=True, **kwargs)
+  :members: forward, set_tensor_parallel_group
+
 .. autoapiclass:: transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs)
 
 .. autoapiclass:: transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs)
@@ -44,3 +47,7 @@ pyTorch
 .. autoapifunction:: transformer_engine.pytorch.make_graphed_callables
 
 .. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
+
+.. autoapifunction:: transformer_engine.pytorch.moe_permute
+
+.. autoapifunction:: transformer_engine.pytorch.moe_unpermute
9 changes: 6 additions & 3 deletions docs/examples/te_llama/te_llama.py
@@ -102,8 +102,11 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k
         Custom method adapted from `from_pretrained` method in HuggingFace
         Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
         """
-        vanilla_model = cls(config).to(kwargs["torch_dtype"])
-        is_local = os.path.isdir(pretrained_model_name_or_path)
+        # Before loading the model, set the default dtype for torch
+        torch.set_default_dtype(kwargs["torch_dtype"])
+
+        # Load the vanilla model weights
+        vanilla_model = cls(config)
         subfolder = ""
         variant = None
         if os.path.isfile(
@@ -133,7 +136,7 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k
         else:
             raise AssertionError("Only sharded PyTorch ckpt format supported at the moment")
 
-        resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
+        resolved_archive_file, _ = get_checkpoint_shard_files(
             pretrained_model_name_or_path,
             archive_file,
         )
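
Setting the default dtype before construction avoids materializing the model in fp32 first and converting afterwards, which transiently holds the larger fp32 copy of the weights. A standalone illustration (not the tutorial code):

    import torch

    torch.set_default_dtype(torch.bfloat16)
    layer = torch.nn.Linear(4, 4)           # parameters allocated directly in bf16
    print(layer.weight.dtype)               # torch.bfloat16
    torch.set_default_dtype(torch.float32)  # restore the usual default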
65 changes: 46 additions & 19 deletions docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
@@ -247,15 +247,24 @@
     "restart_jupyter_notebook()\n",
     "\n",
     "\n",
-    "# Import necessary packages and methods\n",
+    "# Import necessary packages, methods and variables\n",
     "from utils import *\n",
     "\n",
     "\n",
-    "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
-    "## !!! `model_name` attr must point to the location of the model weights !!!\n",
-    "# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n",
-    "# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n",
-    "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
+    "# Provide Huggingface Access Token\n",
+    "hyperparams.hf_access_token = \"\"\n",
+    "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
+    "\n",
+    "# Provide a directory to cache weights in to avoid downloading them every time.\n",
+    "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
+    "hyperparams.weights_cache_dir = \"\"\n",
+    "\n",
+    "# For Llama 2, uncomment this line (also set by default)\n",
+    "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
+    "\n",
+    "# For Llama 3, uncomment this line\n",
+    "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
+    "\n",
     "hyperparams.mixed_precision = \"bf16\"\n",
     "\n",
     "\n",
@@ -554,7 +563,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": 2,
+ "execution_count": 1,
  "id": "bdb34b91",
  "metadata": {},
  "outputs": [
@@ -573,15 +582,24 @@
     "restart_jupyter_notebook()\n",
     "\n",
     "\n",
-    "# Import necessary packages and methods\n",
+    "# Import necessary packages, methods and variables\n",
     "from utils import *\n",
     "\n",
     "\n",
-    "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
-    "## !!! `model_name` attr must point to the location of the model weights !!!\n",
-    "# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n",
-    "# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n",
-    "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
+    "# Provide Huggingface Access Token\n",
+    "hyperparams.hf_access_token = \"\"\n",
+    "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
+    "\n",
+    "# Provide a directory to cache weights in to avoid downloading them every time.\n",
+    "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
+    "hyperparams.weights_cache_dir = \"\"\n",
+    "\n",
+    "# For Llama 2, uncomment this line (also set by default)\n",
+    "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
+    "\n",
+    "# For Llama 3, uncomment this line\n",
+    "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
+    "\n",
     "hyperparams.mixed_precision = \"bf16\"\n",
     "\n",
     "\n",
@@ -653,15 +671,24 @@
     "restart_jupyter_notebook()\n",
     "\n",
     "\n",
-    "# Import necessary packages and methods\n",
+    "# Import necessary packages, methods and variables\n",
     "from utils import *\n",
     "\n",
     "\n",
-    "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
-    "## !!! `model_name` attr must point to the location of the model weights !!!\n",
-    "# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n",
-    "# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n",
-    "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
+    "# Provide Huggingface Access Token\n",
+    "hyperparams.hf_access_token = \"\"\n",
+    "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
+    "\n",
+    "# Provide a directory to cache weights in to avoid downloading them every time.\n",
+    "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
+    "hyperparams.weights_cache_dir = \"\"\n",
+    "\n",
+    "# For Llama 2, uncomment this line (also set by default)\n",
+    "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
+    "\n",
+    "# For Llama 3, uncomment this line\n",
+    "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
+    "\n",
     "hyperparams.mixed_precision = \"fp8\"\n",
     "\n",
     "\n",