
Large performance gap between results trained with LLaMA-Factory and the official LLaVA code #5890

Open
1 task done
zhipeixu opened this issue Nov 1, 2024 · 0 comments
Labels
pending This problem is yet to be addressed

Comments

zhipeixu commented Nov 1, 2024

Reminder

  • I have read the README and searched the existing issues.

System Info

- `llamafactory` version: 0.9.1.dev0
- Platform: Linux-4.19.90-2107.6.0.0192.8.oe1.bclinux.x86_64-x86_64-with-glibc2.35
- Python version: 3.10.15
- PyTorch version: 2.4.1+cu121 (GPU)
- Transformers version: 4.45.2
- Datasets version: 2.21.0
- Accelerate version: 0.34.2
- PEFT version: 0.12.0
- TRL version: 0.9.6
- GPU type: NVIDIA A800 80GB PCIe
- DeepSpeed version: 0.15.2
- Bitsandbytes version: 0.44.1

Reproduction

### model
model_name_or_path: llava-hf/llava-1.5-13b-hf
# llava-hf/llava-v1.6-mistral-7b-hf
# llava-hf/llama3-llava-next-8b-hf

### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all
lora_rank: 128
lora_alpha: 256

### dataset
dataset: CASIA2_vqa,CASIA2_Au_vqa
template: llava
cutoff_len: 4096
max_samples: 7000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/llava1_5-13b/lora/sft/1031-casia_df_aigc_vqa_direct_cot2_tf32
logging_steps: 1
save_strategy: epoch
save_steps: 400
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 6
gradient_accumulation_steps: 1
learning_rate: 2.0e-4
num_train_epochs: 1
lr_scheduler_type: cosine
warmup_ratio: 0.03
weight_decay: 0.
gradient_checkpointing: true
# bf16: true
tf32: true
ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z3_config.json

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: 'no'
eval_steps: 400

report_to: wandb
run_name: LLaVA1.5-13b-LORA-SFT


# llamafactory-cli train  /root/autodl-tmp/LLaMA-Factory/examples/train_lora/llava1_5_lora_sft.yaml

Expected behavior

After training in LLaMA-Factory with the config above, the resulting weights perform much worse at test time than weights trained with the official LLaVA source code (same training set; on the classification task, the officially trained weights reach 90% accuracy, while the LLaMA-Factory weights reach only 70%). This is the LLaVA training script I used:

deepspeed --include localhost:0,1,2,3  --master_port=29501  /root/autodl-tmp/LLaVA/llava/train/train_mem.py \
    --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
    --deepspeed /root/autodl-tmp/LLaVA/scripts/zero3.json \
    --model_name_or_path /root/autodl-tmp/modelzoo/llava-v1.5-13b \
    --version v1 \
    --data_path "${output_dir}/mix_train.json" \
    --image_folder / \
    --vision_tower openai/clip-vit-large-patch14-336 \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
    --group_by_modality_length True \
    --bf16 True \
    --output_dir $output_dir \
    --num_train_epochs 2 \
    --per_device_train_batch_size 6 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "epoch" \
    --save_steps 800 \
    --save_total_limit 1 \
    --learning_rate 2e-4 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 4096 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb
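
For reference, reading the two configurations side by side, a few settings are not obviously aligned. Below is a minimal sketch of the LLaMA-Factory values that would change to mirror the official script (hypothetical and untested; based only on the two configs in this issue):

### train (sketch: values aligned with the official LLaVA script)
num_train_epochs: 2   # the official script trains 2 epochs; the YAML above trains 1
bf16: true            # the official script passes --bf16 True; commented out in the YAML above
# val_size: 0.1 in the YAML above holds out 10% of samples that the official run trains on
# the official script also sets --mm_projector_lr 2e-5, --image_aspect_ratio pad, and
# --group_by_modality_length True, which have no counterpart in the YAML above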

Others

I have two questions:

  1. Why is there such a large gap between the two results? Is there anything in my training setup that is not aligned?
  2. Are there any training tricks that can be used for image-classification tasks with MLLMs?

Looking forward to your answers.
@github-actions github-actions bot added the pending This problem is yet to be addressed label Nov 1, 2024