人工智能讲师叶梓:大模型强化学习训练框架EasyR1 训练流程详解
本文详细介绍了EasyR1训练框架的执行流程,包含启动、初始化、Rollout、奖励计算、更新及验证保存等关键阶段。框架通过Ray分布式系统实现并行处理,采用vLLM生成响应,结合PPO算法进行模型优化。具体包括:1)命令行参数解析和Ray初始化;2)分词器、数据加载器、Worker组和奖励管理器的创建;3)vLLM生成响应序列;4)多阶段奖励和价值计算;5)基于FSDP的梯度更新;6)周期性验证
EasyR1 是一个专门用于大模型(含多模态)强化学习的开源训练框架,由 Zheng 等人于 2025 年开发,是 veRL的分支版本。该框架针对视觉-语言模型(VLMs/MLLMs)的强化学习训练进行了专门优化,能够同时处理视觉和语言信号,在统一的强化学习框架中进行端到端训练。EasyR1 的主要用途包括多模态推理增强(如数学推理、几何问题求解、图表理解)、图像编辑智能体、视频理解、文档 OCR 优化以及遥感图像分析等任务。该框架采用混合引擎设计,支持检查点恢复,与 Wandb 等工具无缝集成,主要使用 GRPO及其各种衍生算法进行策略优化,并提供完整的 RL 训练流程,包括 KL 正则化、奖励计算和分布式训练等功能。目前 EasyR1 已被广泛应用于视觉推理、图像编辑、视频理解等多个前沿研究领域,是当前多模态强化学习领域的重要工具之一,其开源实现可在 GitHub 上获取(https://github.com/hiyouga/EasyR1)。
本文详细描述了 EasyR1 训练框架的执行流程,包括每个步骤中调用的具体程序入口、函数和模块。
OpenClaw 火到爆,90% 人装不上!2026年4 月 11 日 17:30|叶梓老师免费直播零基础保姆级安装,命令行 / 环境坑一次全解。
总体流程
- 启动:运行 python3 -m verl.trainer.main config=examples/config.yaml ...
- 初始化:创建 Runner Ray actor,初始化 tokenizer、worker groups、reward manager
- Rollout 阶段:vLLM worker 为每个 prompt 生成 N 个响应
- 奖励/价值计算:计算 token 级奖励和 KL 惩罚
- 更新阶段:PPO 风格梯度更新(FSDP 支持)
- 验证和保存:周期性验证和检查点保存
1. 启动阶段:命令行入口
命令示例:
python3 -m verl.trainer.main config=examples/config.yaml \
data.train_files=hiyouga/math12k@train \
worker.actor.model.model_path=Qwen2.5-7B-Instruct \
trainer.experiment_name=exp1 \
trainer.n_gpus_per_node=8
调用流程:
- 入口函数:verl/trainer/main.py:99 的 main() 函数
- 调用 OmegaConf.from_cli() 解析命令行参数
- 加载 OmegaConf.structured(PPOConfig()) 默认配置
- 如果包含 config 参数,用 OmegaConf.load(config_path) 加载 YAML 文件
- 用 OmegaConf.merge() 合并默认配置、文件配置和 CLI 覆盖
- 调用 ppo_config.deep_post_init() 初始化配置对象(例如设置路径、填充长度等)
- Ray 初始化:verl/trainer/main.py:112-123
- 检查 ray.is_initialized()
- 未初始化则调用 ray.init(runtime_env=runtime_env),设置环境变量(NCCL_DEBUG、PYTHONUNBUFFERED 等)
- 启动 Runner:verl/trainer/main.py:125-126
- runner = Runner.remote() 创建一个 Ray remote actor
- ray.get(runner.run.remote(ppo_config)) 同步阻塞执行 Runner.run()
2. 初始化阶段:Runner.run()
调用函数:verl/trainer/main.py:30-96 的 Runner.run() 方法
2.1 Tokenizer/Processor 初始化
- Tokenizer:get_tokenizer()(verl/utils/tokenizer.py)
- 参数:config.worker.actor.model.model_path(模型路径)
- 返回 HuggingFace PreTrainedTokenizer 对象
- Processor:get_processor()(verl/utils/tokenizer.py)
- 返回 ProcessorMixin 对象(用于视觉模型,如 Qwen2.5-VL)
2.2 Worker 组配置
- RayWorkerGroup:verl/single_controller/ray/base.py 中的 RayWorkerGroup 类
- 角色映射(main.py:54-58):
- {
Role.ActorRollout: ray.remote(FSDPWorker), # Actor + Rollout 一体化
Role.Critic: ray.remote(FSDPWorker), # Critic 价值函数
Role.RefPolicy: ray.remote(FSDPWorker), # Reference 参考策略
}
- 资源池配置(main.py:59-68):
- ResourcePoolManager:管理 GPU 资源(每个节点上的 GPU 数量、多节点映射)
- resource_pool_spec:每个节点 GPU 数量列表
- mapping:每个角色对应的资源池名称(例如全部使用 “global_pool”)
2.3 Reward Manager 创建
- BatchFunctionRewardManager 或 SequentialFunctionRewardManager:
- 用户自定义奖励函数路径:config.worker.reward.reward_function
- 返回一个 Ray remote 的 FunctionRewardManager
调用示例(main.py:77-79):
RemoteRewardManager = ray.remote(RewardManager).options(num_cpus=config.worker.reward.num_cpus)
reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer)
val_reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer)
2.4 DataLoader 创建
函数:create_dataloader()(verl/trainer/data_loader.py:26-87)
调用位置(main.py:81):
train_dataloader, val_dataloader = create_dataloader(config.data, tokenizer, processor)
内部实现: - 创建 RLHFDataset 对象(verl/utils/dataset.py):
- 从 HuggingFace hub 或本地加载数据
- Tokenize prompts + images(tokenizer, processor)
- 创建 StatefulDataLoader 对象:
- sampler:随机或顺序采样(用于可恢复性)
- collate_fn:批量转换函数
- batch_size:config.data.rollout_batch_size
2.5 RayPPOTrainer 创建与初始化
函数:RayPPOTrainer.__init__()(verl/trainer/ray_trainer.py:164-254)
调用位置(main.py:83-96):
trainer = RayPPOTrainer(
config=config,
tokenizer=tokenizer,
processor=processor,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
)
trainer.init_workers() # 初始化所有 Worker Groups
trainer.fit() # 进入训练循环
__init__ 逻辑:
- 设置强化学习参数(KL 控制、Critic 使用、验证配置)
- 检查 batch size 可除性(rollout_batch_size 与 global_batch_size 之间的关系)
- 计算 training_steps(根据 epoch 或 max_steps)
init_workers() 逻辑(ray_trainer.py:325-393):
- resource_pool_manager.create_resource_pool():创建资源池(不直接创建 worker)
- 为每个角色构建 RayClassWithInitArgs:
- ActorRollout:FSDPWorker(config, role=“actor_rollout”)
- Critic:FSDPWorker(config, role=“critic”)
- RefPolicy:FSDPWorker(config, role=“ref”)
- create_colocated_worker_cls():封装多个 Ray Class
- RayWorkerGroup(resource_pool, ray_cls_with_init):
- resource_pool.get_placement_groups():获取 Placement Groups(多节点 GPU 分配)
- 为每个 GPU 和 Rank 创建并启动 remote actors:
- worker = ray_cls_with_init(placement_group=pg, placement_group_bundle_idx=local_rank, ...)
- wg_dict.spawn(prefix_set=class_dict.keys()):根据前缀(“actor_rollout”, “critic”, “ref”)生成对应的 worker group 接口(每个 worker 支持多个方法前缀,通过 spawn 拆分为独立的接口)
- init_model()(每个 worker group 的 Ray remote 方法):
- Actor/Critic/Ref Worker:
- _build_model_optimizer():加载 HuggingFace 模型,初始化 FSDP 封装、优化器、LR scheduler
- 根据配置 Offload 模型参数/优化器到 CPU - DataParallelPPOActor / DataParallelPPOCritic 封装
- ActorRollout 还会 _build_rollout():
- 初始化 vLLMRollout 对象(verl/workers/rollout/vllm_rollout_spmd.py),构建 vLLM 引擎(包括 KV cache 参数)
- RefPolicy 仅构建模型,不构建优化器
3. Rollout 阶段:生成 N 个响应
入口:RayPPOTrainer.fit() 的第一个内层循环(ray_trainer.py:455-514)
3.1 数据批次准备
调用流程(每训练 step):
- for batch_dict in self.train_dataloader:(ray_trainer.py:477)
- batch: DataProto = DataProto.from_single_dict(batch_dict)(verl/protocol.py)
- 弹出 generation 需要的 keys:
- gen_batch = batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"], ...)
3.2 生成序列(vLLM)
调用函数(带计时):ray_trainer.py:499-500
with timer("gen", timing_raw):
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
相应执行链:
1. RayWorkerGroup(verl/single_controller/ray/base.py)的 generate_sequences() 方法:
- execute_all_sync() / execute_all_async() 将请求分发给所有 rank 对应的 remote worker
2. FSDPWorker(verl/workers/fsdp_workers.py)的 generate_sequences() 方法:
- 如果启用 offload,先 load_fsdp_model(self.fsdp_module)
- 使用 self.rollout_sharding_manager 管理 FSDP 与 vLLM 同步
- 调用 self.rollout.generate_sequences(prompts)(FSDPWorker.generate_sequences 内部)
- vLLMRollout.generate_sequences()(verl/workers/rollout/vllm_rollout_spmd.py:51):
- LLM 引擎(self.inference_engine)生成响应
- 返回 DataProto,包含 responses(response tokens)等字段
返回结果:gen_batch_output 包含每个 prompt 生成的 N 个响应(序列长度、注意力掩码等)
3.3 Prompt 重复和合并
- 生成全局唯一 UUID:batch.non_tensor_batch["uid"] = np.array([uuid.uuid4() for ...])
- Prompt 重复 N 次:batch = batch.repeat(repeat_times=self.config.worker.rollout.n, interleave=True)
- 合并生成结果:batch = batch.union(gen_batch_output)
3.4 序列长度分区(padding-free)
函数:_balance_batch()(ray_trainer.py:438-453)
调用位置:ray_trainer.py:528
self._balance_batch(batch, metrics=metrics)
具体流程:
- 获取 attention_mask,计算每个样本的有效 token 总数 global_seqlen_lst
- 调用 get_seqlen_balanced_partitions()(verl/utils/seqlen_balancing.py)按 seqlen 将样本分配到不同 DP rank
- batch.reorder(global_idx):对 DataProto 重排序
- 后续在请求分发时,每个 DP rank 处理 token 总量近似的样本,降低 padding 浪费
4. 奖励/价值计算阶段
4.1 奖励计算
调用函数:ray_trainer.py:535(异步) + ray_trainer.py:556(等待)
reward_ref = self.reward_fn.compute_reward.remote(batch)
# ... 在 compute_advantage 内:
reward_tensor, reward_metrics = ray.get(reward_ref)
batch.batch["token_level_scores"] = reward_tensor
内部实现:
- 指向 FunctionRewardManager.compute_reward()(verl/workers/reward/function.py:68-69)
- SequentialFunctionRewardManager:
- 对每个样本,tokenizer 解码 response → 调用用户奖励函数(reward_fn(response_str, ground_truth))
- 返回 token_level_rewards(只在最后一个 token 非零)+ reward_metrics
- BatchFunctionRewardManager:
- 批量解码所有 response → 调用 batch 奖励函数(reward_fn(response_list, label_list))
- 返回 Tuple[torch.Tensor, Dict],其中 tensor 形状为 (batch, seqlen),每个 token 位置有对应 reward
4.2 Old LogProbs 计算
调用函数:ray_trainer.py:539
old_log_probs = self.actor_rollout_wg.compute_log_probs(batch)
batch = batch.union(old_log_probs)
- RayWorkerGroup → FSDPWorker 的 compute_log_probs 方法:
- 使用 DataParallelPPOActor._forward_micro_batch() 计算每个 token 的 logprob(VF.log_probs_from_logits)
4.3 Ref LogProbs 计算(可选)
调用函数:ray_trainer.py:544-546
ref_log_probs = self.ref_policy_wg.compute_ref_log_probs(batch)
batch = batch.union(ref_log_probs)
- Ref Worker 仅用于 KL 控制;流程与 actor logprobs 类似(不使用 optimizer)
4.4 Value 计算(可选,仅 GAE)
调用函数:ray_trainer.py:550-552
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
- Critic Worker 输入整段序列,输出每个 token 的 value 估计
4.5 KL 惩罚(可选)
调用函数:apply_kl_penalty()(ray_trainer.py:114-131)
调用位置:ray_trainer.py:564
batch, kl_metrics = apply_kl_penalty(batch, self.kl_ctrl, self.config.algorithm.kl_penalty)
- 计算 old_log_probs 与 ref_log_probs 之间的 KL 散度(core_algos.compute_kl)
- batch.batch["token_level_rewards"] = token_level_scores - kl_ctrl.kl_coef * kld
- 更新 KL 系数(adaptive KL)和 metrics
5. 更新阶段:PPO 风格梯度更新
5.1 优势函数计算
函数:compute_advantage()(ray_trainer.py:134-161)
调用位置:ray_trainer.py:570-575
batch = compute_advantage(
batch,
adv_estimator=self.config.algorithm.adv_estimator, # e.g., "grpo"
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
)
核心逻辑:
- 根据 adv_estimator 选择不同实现(core_algos.py):
- GAE:调用 compute_gae_advantage_return()(需要 values)
- GRPO:调用 compute_grpo_outcome_advantage()(需要 token_level_rewards + uid 分组)
- ReMax:调用 compute_remax_outcome_advantage()(需要 reward_baselines)
- RLOO:调用 compute_rloo_outcome_advantage()(需要 uid 分组)
结果:
- batch.batch["advantages"]:token-level 优势
- batch.batch["returns"]:token-level 回报
5.2 Critic 更新(可选)
调用函数:ray_trainer.py:579-583
critic_output = self.critic_wg.update_critic(batch)
- RayWorkerGroup → FSDPWorker.update_critic → DataParallelPPOCritic.update_critic()(verl/workers/critic/dp_critic.py)
- 使用 MSE loss 拟合 value 网络与 returns
- FSDP 反向传播 + 优化器 step
- 返回 DataProto,内含 non_tensor_batch 指标(loss 等)
5.3 Actor 更新
调用函数:ray_trainer.py:587-591
if self.config.trainer.critic_warmup <= self.global_step:
actor_output = self.actor_rollout_wg.update_actor(batch)
- FSDPWorker.update_actor()(verl/workers/fsdp_workers.py:427-475):
- 如需 offload,先载入模型 load_fsdp_model() 和 optimizer load_fsdp_optimizer()
- 使用 ulysses_sharding_manager.preprocess_data() 对数据分片(序列并行)
- 调用 self.actor.update_policy(data)(DataParallelPPOActor.update_policy)
- PPO clipped surrogate loss + KL loss(如果启用)
- 反向传播和梯度裁剪
- 优化器 step
- LR scheduler step
- 收集指标(MFU、显存占用等)
- Offload 模型和优化器回 CPU
- 返回 DataProto 包含优化指标
6. 验证和保存阶段
6.1 验证(可选)
函数:_validate()(ray_trainer.py:274-323)
调用时机(fit() 内):
1. val_before_train=True:训练前验证(ray_trainer.py:470-474)
2. val_freq > 0 且 global_step % val_freq == 0:训练过程中周期验证(ray_trainer.py:594-602)
3. 训练结束后验证(ray_trainer.py:617-626)
验证流程:
- 遍历 self.val_dataloader - 对每个 batch,弹出 input_ids 等,构造 gen_batch,并覆写 meta_info 为 config.worker.rollout.val_override_config(例如 n=1, temperature=0.5)
- test_output_gen_batch = self.actor_rollout_wg.generate_sequences(test_gen_batch)
- val_reward_fn.compute_reward.remote(test_batch) 计算奖励
- self._maybe_log_val_generations() 选择部分样本记录(方便人工查看)
- 汇总 val/reward_score 等指标
6.2 检查点保存
函数:_save_checkpoint()(ray_trainer.py:395-414)
调用时机:
- trainer.save_freq > 0 且 global_step % save_freq == 0(ray_trainer.py:604-606)
- 训练结束后如果尚未保存(ray_trainer.py:628-629)
保存内容:
- Path:trainer.save_checkpoint_path/global_step_{step}/
- actor/:actor worker checkpoint(通过 FSDPCheckpointManager)
- critic/:critic worker checkpoint - dataloader.pt:dataloader state(保证恢复后仍保持采样顺序)
- Tracker token:写入 last_global_step.txt 防止冲突
恢复逻辑:
- _load_checkpoint():ray_trainer.py:416-436
- 读取 checkpoint 路径下的 actor/、critic/ 和 dataloader.pt,恢复 worker 模型和 dataloader 状态
6.3 Metrics 收集与日志
调用位置:ray_trainer.py:608-614
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, num_gpus=num_gpus))
self.logger.log(data=metrics, step=self.global_step)
- Tracker(verl/utils/logger.py):
- 支持 “console”、“wandb”、“swanlab”、“mlflow”、“tensorboard” 等后端
- 每 step 记录 reward、KL、timing、吞吐(token/s)、MFU 等所有指标
小结
上述流程可用一个简图概括:
用户命令
↓
verl/trainer/main.py:main()
↓
Runner.run() (Ray remote)
├─ get_tokenizer/get_processor
├─ create_dataloader (RLHFDataset, StatefulDataLoader)
├─ RewardManager.remote
└─ RayPPOTrainer
├─ init_workers
│ ├─ resource_pool_manager.create_resource_pool
│ ├─ RayClassWithInitArgs / RayWorkerGroup
│ ├─ wg.spawn → actor_rollout_wg / critic_wg / ref_policy_wg
│ └─ init_model (FSDPWorker)
│ ├─ _build_model_optimizer (FSDP + optim)
│ ├─ _build_rollout (vLLM)
│ └─ DataParallelPPOActor / DataParallelPPOCritic
└─ fit
└─ for epoch, for batch:
├─ batch = DataProto.from_single_dict
├─ gen_batch = batch.pop(...)
├─ generate_sequences (vLLM) → gen_batch_output
├─ batch = batch.union(gen_batch_output)
├─ _balance_batch (seqlen 分区)
├─ compute_reward (RewardManager)
├─ compute_log_probs (actor_rollout_wg)
├─ compute_ref_log_probs (ref_policy_wg)
├─ compute_values (critic_wg, if GAE)
├─ apply_kl_penalty
├─ compute_advantage → advantages / returns
├─ update_critic (if critic)
├─ update_actor (actor_rollout_wg)
├─ validate (if val_freq)
├─ save_checkpoint (if save_freq)
└─ logger.log(metrics)
更多推荐

所有评论(0)