PyTorch implementation of the Vision Transformer. Pretrained PyTorch weights, converted from the original JAX/Flax weights, are provided. This is a project of the ASYML family and CASL.
PyTorch implementation of the paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. We provide pretrained PyTorch weights converted from the pretrained JAX/Flax models, along with fine-tuning and evaluation scripts. The results are close to those of the original implementation.
Create environment:
conda create --name vit --file requirements.txt
conda activate vit
We provide PyTorch model weights, which are converted from the original JAX/Flax weights. You can download them and put the files under 'weights/pytorch' to use them.
Alternatively, you can download the original JAX/Flax weights and put the files under 'weights/jax' to use them. We'll convert the weights for you on the fly.
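As a rough illustration of the layout described above, the sketch below maps a checkpoint file name to the folder it would go in. It assumes the two formats are told apart by file extension ('.pth' for PyTorch, '.npy' for JAX/Flax); `checkpoint_format` and `expected_location` are hypothetical helpers, not part of this repo.

```python
from pathlib import Path

# Assumption: '.pth' marks converted PyTorch weights and '.npy' marks the
# original JAX/Flax weights, matching the example paths in this README.
FORMATS = {".pth": "pytorch", ".npy": "jax"}


def checkpoint_format(path: str) -> str:
    """Return 'pytorch' or 'jax' for a checkpoint path (hypothetical helper)."""
    suffix = Path(path).suffix.lower()
    if suffix not in FORMATS:
        raise ValueError(f"unsupported checkpoint extension: {suffix!r}")
    return FORMATS[suffix]


def expected_location(filename: str, root: str = "weights") -> Path:
    """Folder layout suggested above: weights/pytorch or weights/jax."""
    return Path(root) / checkpoint_format(filename) / filename


if __name__ == "__main__":
    print(expected_location("imagenet21k+imagenet2012_ViT-B_16.pth").as_posix())
```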
Currently three datasets are supported: ImageNet2012, CIFAR10, and CIFAR100. To evaluate or fine-tune on these datasets, download the datasets and put them in 'data/dataset_name'.
More datasets will be supported.
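Before launching a run, it can help to check which datasets are already in place. The sketch below is a minimal check, assuming the folder names match the dataset names used in the example commands ('ImageNet', 'CIFAR10', 'CIFAR100'); the exact names the scripts expect are an assumption, and `find_datasets` is a hypothetical helper.

```python
from pathlib import Path

# Dataset names as they appear in this README's example commands; the folder
# names the scripts actually expect are an assumption.
SUPPORTED = ("ImageNet", "CIFAR10", "CIFAR100")


def find_datasets(data_root: str = "data") -> dict:
    """Map each supported dataset name to True if its folder exists."""
    root = Path(data_root)
    return {name: (root / name).is_dir() for name in SUPPORTED}


if __name__ == "__main__":
    for name, present in find_datasets().items():
        print(f"{name}: {'found' if present else 'missing'}")
```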
python src/train.py --exp-name ft --n-gpu 4 --tensorboard --model-arch b16 --checkpoint-path weights/pytorch/imagenet21k+imagenet2012_ViT-B_16.pth --image-size 384 --batch-size 32 --data-dir data/ --dataset CIFAR10 --num-classes 10 --train-steps 10000 --lr 0.03 --wd 0.0
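To reuse the command above for another supported dataset, only `--dataset` and `--num-classes` should need to change. The sketch below assembles the same argument list in Python; `build_train_command` is a hypothetical helper, and passing `CIFAR100` as the dataset name is an assumption based on the list of supported datasets.

```python
import shlex


def build_train_command(dataset: str, num_classes: int) -> list:
    """Assemble the fine-tuning command shown above for a given dataset."""
    return [
        "python", "src/train.py",
        "--exp-name", "ft",
        "--n-gpu", "4",
        "--tensorboard",
        "--model-arch", "b16",
        "--checkpoint-path", "weights/pytorch/imagenet21k+imagenet2012_ViT-B_16.pth",
        "--image-size", "384",
        "--batch-size", "32",
        "--data-dir", "data/",
        "--dataset", dataset,
        "--num-classes", str(num_classes),
        "--train-steps", "10000",
        "--lr", "0.03",
        "--wd", "0.0",
    ]


if __name__ == "__main__":
    # CIFAR100 has 100 classes; the dataset string itself is an assumption.
    print(shlex.join(build_train_command("CIFAR100", 100)))
```

The resulting list can be passed to `subprocess.run` directly, which avoids shell quoting issues with the `+` in the checkpoint file name.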
Make sure you have downloaded the pretrained weights in either '.npy' or '.pth' format.
python src/eval.py --model-arch b16 --checkpoint-path weights/jax/imagenet21k+imagenet2012_ViT-B_16.npy --image-size 384 --batch-size 128 --data-dir data/ImageNet --dataset ImageNet --num-classes 1000
upstream | model | dataset | orig. jax acc | pytorch acc | model link |
---|---|---|---|---|---|
imagenet21k | ViT-B_16 | imagenet2012 | 84.62 | 83.90 | checkpoint |
imagenet21k | ViT-B_32 | imagenet2012 | 81.79 | 81.14 | checkpoint |
imagenet21k | ViT-L_16 | imagenet2012 | 85.07 | 84.94 | checkpoint |
imagenet21k | ViT-L_32 | imagenet2012 | 82.01 | 81.03 | checkpoint |
Due to limited GPU resources, the fine-tuning results were obtained with a batch size of 32, which may slightly affect performance.
upstream | model | dataset | orig. jax acc | pytorch acc |
---|---|---|---|---|
imagenet21k | ViT-B_16 | CIFAR10 | 98.92 | 98.90 |
imagenet21k | ViT-B_16 | CIFAR100 | 92.26 | 91.65 |
- Colab
- Integrated into Texar
- https://github.com/google-research/vision_transformer
- https://github.com/lucidrains/vit-pytorch
- https://github.com/kamalkraj/Vision-Transformer
Issues and pull requests are welcome for improving this repo. Please follow the contribution guide.