[ICML 2024] Jetfire: Efficient and Accurate Transformer Pretraining with INT8 Data Flow and Per-Block Quantization
This repository provides the official PyTorch implementation of Jetfire: Efficient and Accurate Transformer Pretraining with INT8 Data Flow and Per-Block Quantization.
- 2024.06: Our paper was selected as a spotlight paper!
- 2024.04: Jetfire was accepted by ICML 2024!
The repository contains three main directories: `INT8_GPT2`, `Jetfire`, and `JetfireGEMMKernel`.
```
Jetfire-INT8Training
│
├── INT8_GPT2            # An INT8 training recipe
│   ├── train.py
│   ├── qmodel.py
│   └── ...
│
├── Jetfire              # Implementation of linear and non-linear operators
│   ├── Linear
│   └── Nonlinear
│
└── JetfireGEMMKernel    # CUDA kernels for GEMM
    ├── setup.py
    ├── BlockQuantize
    └── ...
```
The `INT8_GPT2` directory provides a recipe for INT8 training based on nanoGPT. It includes the scripts and configurations needed to enable INT8 training for GPT-2 models. To use INT8 training, modify `train.py` as follows:
- Open `train.py` and locate line 36.
- Change the code to `use_quantize_model=True`.
This will enable the INT8 training mode.
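For orientation, here is a minimal sketch of what that switch might look like. Only the `use_quantize_model` flag and the file `qmodel.py` come from this repository; the imported class names and the `gptconf` variable follow nanoGPT conventions and are placeholders, not this repo's actual API.

```python
# Hypothetical sketch of the toggle around line 36 of INT8_GPT2/train.py.
# Only `use_quantize_model` and `qmodel.py` are taken from this repo; the
# class names and `gptconf` below are illustrative placeholders.
use_quantize_model = True  # False -> standard FP16/BF16 nanoGPT training

if use_quantize_model:
    from qmodel import GPT   # INT8 per-block quantized GPT-2 (placeholder import)
else:
    from model import GPT    # the original nanoGPT model

model = GPT(gptconf)
```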
The training command is:
```bash
cd INT8_GPT2
torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py
```
More details on training can be found in `INT8_GPT2/README.md`.
The `Jetfire` directory contains implementations of both linear and nonlinear operators. It is divided into two subdirectories: `Linear` and `Nonlinear`.
The `Linear` subdirectory contains implementations that use the CUDA kernels from the `JetfireGEMMKernel` directory to perform the linear operations in the forward and backward passes. The primary focus is efficient INT8 matrix multiplication, which is introduced in Section 5 of our paper.
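For intuition only, the sketch below emulates what such a kernel computes: both operands are quantized to INT8 with one scale per block, block products are accumulated, and the result is rescaled by the product of the two block scales. The 32x32 block size and all function names are assumptions for illustration, not the repository's API; the real computation happens in the CUDA kernels under `JetfireGEMMKernel`.

```python
import torch

BLK = 32  # assumed block size for this illustration

def quantize_per_block(x: torch.Tensor, blk: int = BLK):
    """Quantize a 2-D tensor to INT8 with one scale per (blk x blk) block."""
    M, N = x.shape
    xb = x.reshape(M // blk, blk, N // blk, blk)          # xb[a, :, c, :] is block (a, c)
    scale = xb.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp((xb / scale).round(), -127, 127).to(torch.int8)
    return q, scale

def int8_blockwise_matmul(a: torch.Tensor, b: torch.Tensor, blk: int = BLK) -> torch.Tensor:
    """Reference emulation of C = A @ B with INT8 inputs and per-block scales."""
    qa, sa = quantize_per_block(a, blk)                   # (Mb, blk, Kb, blk), scales (Mb, 1, Kb, 1)
    qb, sb = quantize_per_block(b, blk)                   # (Kb, blk, Nb, blk), scales (Kb, 1, Nb, 1)
    Mb, _, Kb, _ = qa.shape
    Nb = qb.shape[2]
    c = torch.zeros(Mb * blk, Nb * blk)
    for i in range(Mb):
        for j in range(Nb):
            acc = torch.zeros(blk, blk)
            for k in range(Kb):
                # int64 stands in for the INT32 accumulator of a real kernel,
                # since integer matmul on CPU is easiest with int64 in PyTorch.
                prod = qa[i, :, k, :].to(torch.int64) @ qb[k, :, j, :].to(torch.int64)
                acc += prod.float() * (sa[i, 0, k, 0] * sb[k, 0, j, 0])
            c[i * blk:(i + 1) * blk, j * blk:(j + 1) * blk] = acc
    return c

a, b = torch.randn(64, 128), torch.randn(128, 96)
print((int8_blockwise_matmul(a, b) - a @ b).abs().max())  # small error vs. the exact FP product
```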
The `Nonlinear` subdirectory contains implementations of nonlinear operators such as `GELU`, `LayerNorm`, `Quantize`, and stochastic rounding, leveraging Triton for high performance. These operators are introduced in Section 6 of our paper.
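As a point of reference for what the stochastic-rounding step does, here is a plain-PyTorch version for intuition only; the actual operators in `Nonlinear` are Triton kernels, and the INT8 range used below is the only other assumption.

```python
import torch

def stochastic_round(x: torch.Tensor) -> torch.Tensor:
    """Round up or down with probability given by the fractional part,
    so the result is unbiased in expectation: E[sr(x)] = x."""
    floor = torch.floor(x)
    prob_up = x - floor                                    # fractional part in [0, 1)
    return floor + (torch.rand_like(x) < prob_up).to(x.dtype)

# Unbiasedness check: nearest rounding maps 0.3 -> 0.0 every time, while
# stochastic rounding averages back to ~0.3 over many samples.
x = torch.full((100_000,), 0.3)
print(stochastic_round(x).mean())                          # ~0.3

def quantize_int8_stochastic(x: torch.Tensor):
    """Scale to the INT8 grid, stochastically round, then clamp (reference only)."""
    scale = x.abs().max().clamp(min=1e-8) / 127.0
    q = torch.clamp(stochastic_round(x / scale), -127, 127).to(torch.int8)
    return q, scale
```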
The `JetfireGEMMKernel` directory includes CUDA kernels specifically designed for INT8 matrix multiplication. These kernels are used by the `Linear` implementations in the `Jetfire` directory to achieve high-performance linear operations.
To get started with this repository, clone it and install the GEMM kernels:
```bash
git clone https://github.com/thu-ml/Jetfire-INT8Training.git
cd Jetfire-INT8Training
cd JetfireGEMMKernel
python setup.py install
cd ..
```
To install Triton, we pin this specific nightly build because the Triton API may change between versions:
```bash
pip install https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/07c94329-d4c3-4ad4-9e6b-f904a60032ec/pypi/download/triton-nightly/3.post20240610003544/triton_nightly-3.0.0.post20240610003544-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=ac2c36a49bf9c2bb780909b38096fb718f17efd78b88a1ca1d649f6d063cdc2c
```
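After installation, you can sanity-check that the pinned nightly is the one Python picks up; the distribution name `triton-nightly` below is taken from the wheel filename above.

```python
# Confirm the pinned Triton nightly is installed and importable.
from importlib.metadata import version
import triton  # the nightly wheel still provides the `triton` module

print(version("triton-nightly"))  # expected: 3.0.0.post20240610003544
```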
For INT8 GPT-2 training, follow the instructions in the `INT8_GPT2` section above. To develop or experiment with the linear and nonlinear operators, explore the `Jetfire` directory.
If you find our work helpful or interesting, please cite it :)
```bibtex
@article{xi2024jetfire,
  title={Jetfire: Efficient and Accurate Transformer Pretraining with INT8 Data Flow and Per-Block Quantization},
  author={Xi, Haocheng and Chen, Yuxiang and Zhao, Kang and Zheng, Kaijun and Chen, Jianfei and Zhu, Jun},
  journal={arXiv preprint arXiv:2403.12422},
  year={2024}
}
```