EasyR1 GRPO训练vlm模型尝试
注意:数据集在huggingface上存储为parquet格式,但是数据加载和处理的底层格式是arrow(arrow和parquet格式数据:https://blog.csdn.net/shizheng_Li/article/details/144132714)然而本地的.arrow格式数据集会报错,因此还是推送数据集到huggingface。最近要使用GRPO训一个vlm模型,听说easy R1
·
最近要使用GRPO训一个vlm模型,听说easy R1是一个比较稳定的框架,尝试一下
官方链接:https://github.com/hiyouga/EasyR1
- 先拉取镜像:
docker pull hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0
- 创建一个新文件,命名为Dockerfile(无扩展名),粘贴如下内容:
FROM hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0
WORKDIR /test1
RUN git clone https://github.com/hiyouga/EasyR1.git
WORKDIR /test1/EasyR1
RUN pip install -e . -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
RUN pip install swanlab -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
docker build -t kevinchina/deeplearning:EasyR1 .
自动寻找并基于当前目录下的 Dockerfile 构建镜像- 准备数据集:
参考Easy R1数据模板 https://huggingface.co/datasets/hiyouga/geometry3k
即单个数据样本符合模板:
yield {
"images": [image],
"problem": "<image>" + data["annotat_text"],
"answer": data["answer"],
}
并划分训练集和测试集
注意:数据集在huggingface上存储为parquet格式,但是数据加载和处理的底层格式是arrow(arrow和parquet格式数据:https://blog.csdn.net/shizheng_Li/article/details/144132714)然而本地的.arrow格式数据集会报错,因此还是推送数据集到huggingface
Hugging face登录:
- 在服务器
huggingface-cli login
- 输入:
<huggingface 令牌>
git config --global credential.helper store
将 Git 凭据(用户名和密码)永久保存在本地磁盘上,避免每次推送(git push)或拉取(git pull)时重复输入密码dataset_dict.push_to_hub("<user name>/<dataset name>")
- 构建一个容器:
docker run -it \
--net host \
--gpus '"device=0,1"' \
--shm-size=64g \
-v <本地模型地址>:/model \
kevinchina/deeplearning:EasyR1 \
bash
- 修改.sh文件
vim examples/qwen2_5_vl_3b_geo3k_grpo.sh
- 修改模型地址:
MODEL_PATH=/model
- 修改数据路径:
<user name>/<dataset name>
- 添加行
trainer.logger=['console','swanlab']
- 修改.yaml文件
vim examples/config.yaml
这个文件里参数比较多,重点修改bitch_size(决定了GPU的显存是否够用)、模型路径、输出日志logger、epoch、输出日志的项目名称、奖励模型、训练和验证数据集等,根据自己的情况吧。 - 修改奖励函数r1v.py的内容
vim examples/reward_function/r1v.py
如果是自定义的奖励函数,粘贴过来即可 - 设置环境变量并运行:
export SWANLAB_API_KEY=<your key> # 设置在线跟踪模式API
export SWANLAB_LOG_DIR=/swanlab_log # 设置本地日志存储路径
export SWANLAB_MODE=cloud # cloud云端跟踪模式
bash examples/qwen2_5_vl_3b_geo3k_grpo.sh
经过一段时间之后,就能在SWANLAB上看到训练情况了
处理报错:
out of memory
超出内存,batchsize等调小ValueError: Rollout batch size must be divisible by actor global batch size.
,即Rollout batch size( rollout 阶段的批次大小)不能被 actor global batch size(参与训练的全局批次大小)整除。将Rollout batch size设置为actor global batch size的整数倍既可pyarrow.lib.ArrowInvalid: Parquet magic bytes not found in footer. Either the file is corrupted or this is not a parquet file.
程序在尝试读取一个 Parquet 文件时失败了,具体原因是文件格式不正确或已损坏,.arrow格式数据集换成Parquet格式即可
更多推荐
所有评论(0)