Yuexiang Zhai*, Hao Bai†, Zipeng Lin†, Jiayi Pan†, Shengbang Tong†, Yifei Zhou†
Alane Suhr, Saining Xie, Yann LeCun, Yi Ma, Sergey Levine
*Project Lead, †Equal Contribution.
Paper | Project Page | Wandb Report | Data
- [Dec 15, 2024] We have archived our repo due to a lack of effort in maintaining the package dependencies. If you are interested in doing follow-up work or setting up your own codebase, some tips are: (1) always use up-to-date backbone models (e.g., Qwen, Llama-3.2V, etc.); (2) design your own tasks by writing environments; (3) customize your CoT for running RL.
- [Dec 13, 2024] Feel free to check our NeurIPS poster.
- [Sep 26, 2024] Our paper will appear at NeurIPS 2024.
- [Sep 13, 2024] (Important!) Our codebase does not pin the version of `tokenizers`, so the hard-coded token ids for the string `'"action":'` (here) may be outdated for your tokenizer version. We apologize for the inconvenience and suggest manually checking these token ids before use (or developing a better strategy for obtaining them); a minimal check is sketched after this list.
- [Aug 7, 2024] We have uploaded a .zip file for the gym_cards environment. If you do not have the corresponding fonts, please consider downloading them.
- [June 7, 2024] We have prepared a template text wrapper to utilize our gym-cards environment in pure text. See examples here.
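Regarding the [Sep 13, 2024] note above, here is a minimal sketch of checking which token ids your tokenizer assigns to the string `'"action":'`. The checkpoint name below is just the llava-v1.6-mistral-7b model referenced in this README; substitute whatever tokenizer your checkpoint actually uses. Also note that SentencePiece tokenization can depend on surrounding text, so consider checking the string as it appears inside your actual prompt as well.

```python
# Sketch: print the token ids your tokenizer assigns to '"action":'.
# The checkpoint below is only an example; use the tokenizer of your own model.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("liuhaotian/llava-v1.6-mistral-7b")

text = '"action":'
ids = tokenizer(text, add_special_tokens=False)["input_ids"]
print(ids)                                    # compare with the hard-coded ids
print(tokenizer.convert_ids_to_tokens(ids))   # inspect the corresponding tokens
```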
Our project contains three different codebases:
- A slightly modified version of LLaVA.
  - See our `git diff` from the LLaVA branch here.
- Our original GymCards environment (see the usage sketch after this list).
- The RL4VLM codebases for both the GymCards and ALFWorld environments.
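For orientation, below is a minimal interaction sketch for the GymCards environment referenced above, written against the Gymnasium API. The environment id `gym_cards/NumberLine-v0` is an assumption for illustration only; please check the registration code in the gym_cards package for the exact ids and for the observation/action spaces.

```python
# Sketch: random-agent rollout in a GymCards environment.
# Assumption: gym_cards registers Gymnasium ids such as "gym_cards/NumberLine-v0";
# verify the exact id and the spaces against the package's registration code.
import gymnasium as gym
import gym_cards  # noqa: F401  (importing the package registers the environments)

env = gym.make("gym_cards/NumberLine-v0")
obs, info = env.reset(seed=0)

terminated = truncated = False
while not (terminated or truncated):
    action = env.action_space.sample()  # in RL4VLM, the VLM's decoded action goes here
    obs, reward, terminated, truncated, info = env.step(action)

env.close()
```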
Our training pipeline consists of two steps:
- Prepare an SFT checkpoint.
  - Check here to download the instruction-following data we prepared for running the initial SFT.
  - We provide a template script (adapted from the official finetune.sh for the 1.6-mistral model) for running LLaVA SFT. Please remember to set `--data_path`, `--image_folder`, and `--output_dir` accordingly.
  - Please follow the instructions for LLaVA fine-tuning here.
  - Our experiments start from the llava-1.6-mistral-7b checkpoint. You are welcome to use any initial model, but we cannot guarantee similar performance with other choices.
- Run RL using the SFT checkpoint.
  - For GymCards, please use these .sh run scripts.
    - Check here for conda environment installation.
    - [important] You may change the `num_processes` in config_zero2.yaml to the number of GPUs you have. Please make sure the number of GPUs you set in `CUDA_VISIBLE_DEVICES` in the `.sh` file is `>=` the `num_processes` in config_zero2.yaml (a quick consistency check is sketched after this list).
    - [important] If you only want to play around with our codebase rather than reproduce our results, you may skip the SFT in step 1 and directly use the llava-1.6 model `liuhaotian/llava-v1.6-mistral-7b` as your initial model in `--model-path`.
  - For ALFWorld, please use this run file.
    - Check here for conda environment installation.
    - The `num_processes` in config_zero2.yaml and the number of GPUs in the `run_alf.sh` file should follow the same rule as GymCards. We recommend using only 1 GPU for ALFWorld, because the time for on-policy data collection varies widely across GPUs, which may lead to an NCCL timeout during the synchronization of different processes when multiple GPUs are used.
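Before launching a run, a small consistency check like the one below may save a failed job. It assumes config_zero2.yaml sits in the current directory and exposes a top-level `num_processes` key, as accelerate configs typically do; adjust the path to wherever your copy lives.

```python
# Sketch: verify CUDA_VISIBLE_DEVICES exposes at least `num_processes` GPUs.
# Assumes config_zero2.yaml is an accelerate config with a top-level num_processes key.
import os
import yaml

with open("config_zero2.yaml") as f:
    num_processes = yaml.safe_load(f)["num_processes"]

visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
num_gpus = len([d for d in visible.split(",") if d.strip()])

assert num_gpus >= num_processes, (
    f"CUDA_VISIBLE_DEVICES exposes {num_gpus} GPU(s), "
    f"but num_processes={num_processes} in config_zero2.yaml."
)
print(f"OK: {num_gpus} visible GPU(s) >= num_processes={num_processes}")
```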
This project is under the MIT License.
If you find our codebases useful, please consider citing our paper:
@inproceedings{
zhai2024finetuning,
title={Fine-Tuning Large Vision-Language Models as Decision-Making Agents via Reinforcement Learning},
author={Yuexiang Zhai and Hao Bai and Zipeng Lin and Jiayi Pan and Shengbang Tong and Yifei Zhou and Alane Suhr and Saining Xie and Yann LeCun and Yi Ma and Sergey Levine},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
year={2024},
url={https://openreview.net/forum?id=nBjmMF2IZU}
}
Our codebases adopt LLaVA as a backbone model and apply PPO from this repo for RL fine-tuning. In principle, one may try to adapt our pipeline to different VLM / MLLM backbones and different RL algorithms.
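To make the adaptation point concrete, here is a minimal, self-contained sketch of the core idea behind RL fine-tuning a generative backbone on environment reward. It is not the training code of this repo: it uses a plain REINFORCE update instead of PPO, a tiny `gpt2` placeholder instead of a LLaVA backbone, and a constant placeholder reward instead of an environment. It only illustrates where a different backbone or RL algorithm would plug in.

```python
# Sketch: one REINFORCE-style policy-gradient step on sampled text (placeholder setup).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder backbone; the repo starts from llava-v1.6-mistral-7b
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

prompt = 'Current number: 3, target: 5. Reply with {"action": ...}.'
inputs = tokenizer(prompt, return_tensors="pt")
prompt_len = inputs["input_ids"].shape[1]

# 1) Sample an "action" (a text continuation) from the current policy.
with torch.no_grad():
    sequences = model.generate(
        **inputs, do_sample=True, max_new_tokens=16,
        pad_token_id=tokenizer.eos_token_id,
    )

# 2) Recompute log-probs of the sampled tokens with gradients enabled.
logits = model(sequences).logits[:, :-1, :]                # predicts tokens 1..L-1
logprobs = torch.log_softmax(logits, dim=-1)
token_logprobs = logprobs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
action_logprob = token_logprobs[:, prompt_len - 1:].sum()  # log-prob of generated tokens

# 3) Policy-gradient step, scaled by a placeholder reward.
reward = 1.0  # in RL4VLM this comes from GymCards / ALFWorld
loss = -reward * action_logprob
loss.backward()
optimizer.step()
optimizer.zero_grad()
```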