机器人抓取与操作学习(四):模型训练
以下命令用于执行一个 Python 脚本 imitate_episodes.py,该脚本主要用于训练或执行一个模仿学习 (imitation learning) 的过程。在命令行中提供的参数影响训练过程的各个方面。命令功能和用法该命令的目的是启动一个模仿学习的训练过程,主要用于训练模型以模仿专家的行为,特别是在一个仿真环境中进行抓取和放置立方体的任务。通过使用提供的参数,用户可以灵活地调整训练过程
以下命令用于执行一个 Python 脚本 imitate_episodes.py,该脚本主要用于训练或执行一个模仿学习 (imitation learning) 的过程。在命令行中提供的参数影响训练过程的各个方面。
python3 imitate_episodes.py \ # 使用 Python 3 执行 imitate_episodes.py 脚本
--task_name sim_pick_n_place_cube_scripted \ # 指定任务名称为 "sim_pick_n_place_cube_scripted",表示一个仿真任务,涉及抓取和放置立方体
--ckpt_dir ckpt_dir \ # 指定检查点目录为 "ckpt_dir",用于加载模型权重或保存训练过程中的检查点
--policy_class ACT \ # 指定策略类为 "ACT",表示使用的策略网络类
--kl_weight 10 \ # 设置 KL 散度权重为 10,衡量生成策略和专家策略之间的差异
--chunk_size 100 \ # 设置每个训练批次的大小为 100,控制每次训练中处理的数据量
--hidden_dim 512 \ # 设置神经网络隐藏层的维度为 512,影响模型的复杂性和表达能力
--batch_size 8 \ # 指定每次迭代中处理的样本数量为 8,控制训练的批量大小
--dim_feedforward 3200 \ # 设置前馈网络的维度为 3200,通常用于 Transformer 等模型的参数
--num_epochs 2000 \ # 指定训练的总轮数为 2000,控制训练时间和模型的学习程度
--lr 1e-5 \ # 设置学习率为 0.00001(1e-5),优化算法的步长
--seed 0 \ # 指定随机种子为 0,确保模型训练结果的可重复性
--temporal_agg # 该标志指示训练过程中使用时间聚合技术,帮助捕捉时序数据特征
命令功能和用法
该命令的目的是启动一个模仿学习的训练过程,主要用于训练模型以模仿专家的行为,特别是在一个仿真环境中进行抓取和放置立方体的任务。通过使用提供的参数,用户可以灵活地调整训练过程中的各种设置,从而优化模型的表现。
参数的意义
-
--task_name sim_pick_n_place_cube_scripted:- 这个参数指定了任务的名称,通常与具体的仿真环境或数据集相关。在这里,它表示一个模拟的抓取和放置立方体的任务。
-
--ckpt_dir ckpt_dir:- 指定了检查点目录。这是一个重要的参数,用于保存训练过程中生成的模型权重,以便后续加载和恢复训练或测试。如果指定的目录不存在,程序可能会报错。
-
--policy_class ACT:- 这个参数定义了所使用的策略类。在模仿学习中,策略网络负责生成动作以模仿专家的行为,
ACT可能是一个自定义的策略实现。
- 这个参数定义了所使用的策略类。在模仿学习中,策略网络负责生成动作以模仿专家的行为,
-
--kl_weight 10:- KL 散度权重用于衡量生成的策略与专家策略之间的差异。较大的 KL 权重值(如 10)可以鼓励模型更准确地模仿专家的行为。
-
--chunk_size 100:- 这个参数控制每个训练批次中处理的时间步数。将数据分成多个块(chunk)可以提高训练的效率,并有助于模型更快收敛。
-
--hidden_dim 512:- 设置神经网络隐藏层的维度。隐藏层越大,模型的表达能力越强,但也会增加计算复杂度和训练时间。
-
--batch_size 8:- 批量大小控制每次训练迭代中使用的样本数量。较小的批量大小可以提高模型的泛化能力,但训练时间会相应增加。
-
--dim_feedforward 3200:- 这个参数通常用于定义前馈网络的尺寸,影响模型在处理输入时的计算能力和深度。
-
--num_epochs 2000:- 指定训练的总轮数。训练过程中的每一轮都会遍历整个训练集,轮数越多,模型有更多机会学习数据中的模式,但可能会导致过拟合。
-
--lr 1e-5:- 学习率是优化算法中一个关键参数,决定模型参数更新的步长。学习率设置得过高可能导致训练不稳定,设置得过低则可能导致训练速度过慢。
-
--seed 0:- 随机种子用于确保实验的可重复性。通过设置相同的种子,可以确保每次运行程序时生成相同的随机数,从而得到一致的结果。
-
--temporal_agg:- 这个标志指示程序在训练过程中使用时间聚合技术。这种技术有助于更好地处理时序数据,捕捉时间序列中的相关特征。
该命令启动了一项复杂的模仿学习任务,涉及多种参数配置,以便在特定的模拟环境中训练模型。用户可以根据自己的需求调整这些参数,以优化训练过程和最终模型的性能。
下面是imitate_episodes.py的中文注释版本,并对程序结构、功能的详细讲解。
import torch
import numpy as np
import os
import pickle
import argparse
import matplotlib.pyplot as plt
from copy import deepcopy
from tqdm import tqdm
from einops import rearrange
# 导入常量
from constants import DT # 时间间隔
from constants import PUPPET_GRIPPER_JOINT_OPEN # 控制机器手爪打开的常量
# 导入数据加载和处理函数
from utils import load_data
from utils import sample_box_pose, sample_insertion_pose # 机器人功能函数
from utils import compute_dict_mean, set_seed, detach_dict # 辅助函数
from policy import ACTPolicy, CNNMLPPolicy # 导入两种策略
from visualize_episodes import save_videos # 可视化函数
from sim_env import BOX_POSE # 模拟环境
import IPython
e = IPython.embed # 交互式Python调试
def main(args):
set_seed(1) # 设置随机种子,确保可重复性
# 命令行参数
is_eval = args['eval'] # 是否为评估模式
ckpt_dir = args['ckpt_dir'] # 检查点目录
policy_class = args['policy_class'] # 策略类型
onscreen_render = args['onscreen_render'] # 是否在屏幕上渲染
task_name = args['task_name'] # 任务名称
batch_size_train = args['batch_size'] # 训练批量大小
batch_size_val = args['batch_size'] # 验证批量大小
num_epochs = args['num_epochs'] # 训练周期数
# 获取任务参数
is_sim = task_name[:4] == 'sim_' # 判断是否为模拟任务
if is_sim:
from constants import SIM_TASK_CONFIGS
task_config = SIM_TASK_CONFIGS[task_name] # 获取模拟任务配置
else:
from aloha_scripts.constants import TASK_CONFIGS
task_config = TASK_CONFIGS[task_name] # 获取真实任务配置
dataset_dir = task_config['dataset_dir'] # 数据集目录
num_episodes = task_config['num_episodes'] # 任务回合数
episode_len = task_config['episode_len'] # 每个回合的长度
camera_names = task_config['camera_names'] # 相机名称
# 固定参数
state_dim = 7 # 状态维度
lr_backbone = 1e-5 # 骨干学习率
backbone = 'resnet18' # 骨干网络结构
if policy_class == 'ACT':
enc_layers = 4 # 编码层数
dec_layers = 7 # 解码层数
nheads = 8 # 注意力头数
policy_config = {
'lr': args['lr'],
'num_queries': args['chunk_size'],
'kl_weight': args['kl_weight'],
'hidden_dim': args['hidden_dim'],
'dim_feedforward': args['dim_feedforward'],
'lr_backbone': lr_backbone,
'backbone': backbone,
'enc_layers': enc_layers,
'dec_layers': dec_layers,
'nheads': nheads,
'camera_names': camera_names,
}
elif policy_class == 'CNNMLP':
policy_config = {
'lr': args['lr'],
'lr_backbone': lr_backbone,
'backbone': backbone,
'num_queries': 1,
'camera_names': camera_names,
}
else:
raise NotImplementedError
config = {
'num_epochs': num_epochs,
'ckpt_dir': ckpt_dir,
'episode_len': episode_len,
'state_dim': state_dim,
'lr': args['lr'],
'policy_class': policy_class,
'onscreen_render': onscreen_render,
'policy_config': policy_config,
'task_name': task_name,
'seed': args['seed'],
'temporal_agg': args['temporal_agg'],
'camera_names': camera_names,
'real_robot': not is_sim
}
# 如果是评估模式
if is_eval:
ckpt_names = [f'policy_best.ckpt'] # 评估使用的检查点
results = []
for ckpt_name in ckpt_names:
success_rate, avg_return = eval_bc(config, ckpt_name, save_episode=True)
results.append([ckpt_name, success_rate, avg_return])
for ckpt_name, success_rate, avg_return in results:
print(f'{ckpt_name}: {success_rate=} {avg_return=}')
print()
exit()
# 加载数据
train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val)
# 保存数据集统计信息
if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir)
stats_path = os.path.join(ckpt_dir, f'dataset_stats.pkl')
with open(stats_path, 'wb') as f:
pickle.dump(stats, f)
# 训练模型
best_ckpt_info = train_bc(train_dataloader, val_dataloader, config)
best_epoch, min_val_loss, best_state_dict = best_ckpt_info
# 保存最佳检查点
ckpt_path = os.path.join(ckpt_dir, f'policy_best.ckpt')
torch.save(best_state_dict, ckpt_path)
print(f'最佳检查点,验证损失 {min_val_loss:.6f} @ epoch{best_epoch}')
def make_policy(policy_class, policy_config):
if policy_class == 'ACT':
policy = ACTPolicy(policy_config) # 创建 ACT 策略
elif policy_class == 'CNNMLP':
policy = CNNMLPPolicy(policy_config) # 创建 CNN-MLP 策略
else:
raise NotImplementedError
return policy
def make_optimizer(policy_class, policy):
if policy_class == 'ACT':
optimizer = policy.configure_optimizers() # 配置 ACT 优化器
elif policy_class == 'CNNMLP':
optimizer = policy.configure_optimizers() # 配置 CNN-MLP 优化器
else:
raise NotImplementedError
return optimizer
def get_image(ts, camera_names):
curr_images = []
for cam_name in camera_names:
curr_image = rearrange(ts.observation['images'][cam_name], 'h w c -> c h w') # 调整图像维度
curr_images.append(curr_image)
curr_image = np.stack(curr_images, axis=0) # 堆叠图像
curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0) # 转为张量并归一化
return curr_image
def eval_bc(config, ckpt_name, save_episode=True):
set_seed(1000) # 设置随机种子
ckpt_dir = config['ckpt_dir']
state_dim = config['state_dim']
real_robot = config['real_robot']
policy_class = config['policy_class']
onscreen_render = config['onscreen_render']
policy_config = config['policy_config']
camera_names = config['camera_names']
max_timesteps = config['episode_len']
task_name = config['task_name']
temporal_agg = config['temporal_agg']
onscreen_cam = 'angle'
# 加载策略和统计信息
ckpt_path = os.path.join(ckpt_dir, ckpt_name)
policy = make_policy(policy_class, policy_config)
loading_status = policy.load_state_dict(torch.load(ckpt_path)) # 加载模型权重
print(loading_status)
policy.cuda()
policy.eval()
print(f'加载成功: {ckpt_path}')
stats_path = os.path.join(ckpt_dir, f'dataset_stats.pkl')
with open(stats_path, 'rb') as f:
stats = pickle.load(f)
# 预处理和后处理函数
pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
post_process = lambda a: a * stats['action_std'] + stats['action_mean']
# 加载环境
if real_robot:
from aloha_scripts.robot_utils import move_grippers # 需要 aloha
from aloha_scripts.real_env import make_real_env # 需要 aloha
env = make_real_env(init_node=True)
env_max_reward = 0 # 真实机器人最大奖励
else:
from sim_env import make_sim_env
env = make_sim_env(task_name) # 创建模拟环境
env_max_reward = env.task.max_reward # 获取模拟环境最大奖励
query_frequency = policy_config['num_queries'] # 查询频率
if temporal_agg:
query_frequency = 1
num_queries = policy_config['num_queries']
max_timesteps = int(max_timesteps * 1) # 可能根据真实任务增加时间步数
num_rollouts = 50 # 总回合数
episode_returns = [] # 存储回合奖励
highest_rewards = [] # 存储最高奖励
for rollout_id in range(num_rollouts):
rollout_id += 0
### 设置任务
if 'sim_transfer_cube' in task_name:
BOX_POSE[0] = sample_box_pose() # 用于模拟重置
elif 'sim_pick_n_place_cube' in task_name:
BOX_POSE[0] = sample_box_pose()
elif 'sim_insertion' in task_name:
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # 用于模拟重置
ts = env.reset() # 重置环境
### 屏幕渲染
if onscreen_render:
ax = plt.subplot()
plt_img = ax.imshow(env._physics.render(height=480, width=640, camera_id=onscreen_cam))
plt.ion()
### 评估循环
if temporal_agg:
all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, state_dim]).cuda()
qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda() # 存储位置历史
image_list = [] # 用于可视化
qpos_list = []
target_qpos_list = []
rewards = []
with torch.inference_mode():
for t in range(max_timesteps):
### 更新屏幕渲染并等待时间间隔 DT
if onscreen_render:
image = env._physics.render(height=480, width=640, camera_id=onscreen_cam)
plt_img.set_data(image)
plt.pause(DT)
### 处理前一个时间步以获取位置和图像列表
obs = ts.observation
if 'images' in obs:
image_list.append(obs['images'])
else:
image_list.append({'main': obs['image']})
qpos_numpy = np.array(obs['qpos'])
qpos = pre_process(qpos_numpy) # 预处理关节位置
qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0) # 转为张量并归一化
qpos_history[:, t] = qpos
curr_image = get_image(ts, camera_names) # 获取图像
### 查询策略
if config['policy_class'] == "ACT":
if t % query_frequency == 0:
all_actions = policy(qpos, curr_image) # 获取当前动作
if temporal_agg:
all_time_actions[[t], t:t+num_queries] = all_actions
actions_for_curr_step = all_time_actions[:, t]
actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
actions_for_curr_step = actions_for_curr_step[actions_populated]
k = 0.01
exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
exp_weights = exp_weights / exp_weights.sum()
exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
else:
raw_action = all_actions[:, t % query_frequency]
elif config['policy_class'] == "CNNMLP":
raw_action = policy(qpos, curr_image)
else:
raise NotImplementedError
### 后处理动作
raw_action = raw_action.squeeze(0).cpu().numpy()
action = post_process(raw_action) # 后处理动作
target_qpos = action
### 环境一步
ts = env.step(target_qpos)
### 可视化
qpos_list.append(qpos_numpy)
target_qpos_list.append(target_qpos)
rewards.append(ts.reward)
plt.close() # 关闭图像窗口
if real_robot:
move_grippers([env.puppet_bot_left, env.puppet_bot_right], [PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5) # 打开机械手爪
pass
rewards = np.array(rewards)
episode_return = np.sum(rewards[rewards != None]) # 计算回合奖励
episode_returns.append(episode_return)
episode_highest_reward = np.max(rewards) # 获取最高奖励
highest_rewards.append(episode_highest_reward)
print(f'回合 {rollout_id}\n{episode_return=}, {episode_highest_reward=}, {env_max_reward=}, 成功: {episode_highest_reward==env_max_reward}')
if save_episode:
save_videos(image_list, DT, video_path=os.path.join(ckpt_dir, f'video{rollout_id}.mp4')) # 保存视频
success_rate = np.mean(np.array(highest_rewards) == env_max_reward) # 计算成功率
avg_return = np.mean(episode_returns) # 计算平均回合奖励
summary_str = f'\n成功率: {success_rate}\n平均奖励: {avg_return}\n\n'
for r in range(env_max_reward + 1):
more_or_equal_r = (np.array(highest_rewards) >= r).sum()
more_or_equal_r_rate = more_or_equal_r / num_rollouts
summary_str += f'奖励 >= {r}: {more_or_equal_r}/{num_rollouts} = {more_or_equal_r_rate * 100}%\n'
print(summary_str)
# 保存成功率到文本文件
result_file_name = 'result_' + ckpt_name.split('.')[0] + '.txt'
with open(os.path.join(ckpt_dir, result_file_name), 'w') as f:
f.write(summary_str)
f.write(repr(episode_returns))
f.write('\n\n')
f.write(repr(highest_rewards))
return success_rate, avg_return
def forward_pass(data, policy):
image_data, qpos_data, action_data, is_pad = data
image_data, qpos_data, action_data, is_pad = image_data.cuda(), qpos_data.cuda(), action_data.cuda(), is_pad.cuda() # 移动数据到GPU
return policy(qpos_data, image_data, action_data, is_pad) # 执行前向传播
def train_bc(train_dataloader, val_dataloader, config):
num_epochs = config['num_epochs'] # 总训练轮数
ckpt_dir = config['ckpt_dir'] # 检查点目录
seed = config['seed'] # 随机种子
policy_class = config['policy_class'] # 策略类型
policy_config = config['policy_config'] # 策略配置
set_seed(seed) # 设置随机种子
policy = make_policy(policy_class, policy_config) # 创建策略
policy.cuda() # 移动策略到GPU
optimizer = make_optimizer(policy_class, policy) # 创建优化器
train_history = [] # 训练过程记录
validation_history = [] # 验证过程记录
min_val_loss = np.inf # 最小验证损失
best_ckpt_info = None # 最佳检查点信息
for epoch in tqdm(range(num_epochs)):
print(f'\n轮次 {epoch}')
# 验证
with torch.inference_mode():
policy.eval() # 切换到评估模式
epoch_dicts = []
for batch_idx, data in enumerate(val_dataloader):
forward_dict = forward_pass(data, policy) # 前向传播
epoch_dicts.append(forward_dict)
epoch_summary = compute_dict_mean(epoch_dicts) # 计算验证平均记录
validation_history.append(epoch_summary)
epoch_val_loss = epoch_summary['loss'] # 验证损失
if epoch_val_loss < min_val_loss:
min_val_loss = epoch_val_loss # 更新最小验证损失
best_ckpt_info = (epoch, min_val_loss, deepcopy(policy.state_dict())) # 保存最佳检查点信息
print(f'验证损失: {epoch_val_loss:.5f}')
summary_string = ''
for k, v in epoch_summary.items():
summary_string += f'{k}: {v.item():.3f} '
print(summary_string)
# 训练
policy.train() # 切换到训练模式
optimizer.zero_grad() # 清零优化器梯度
for batch_idx, data in enumerate(train_dataloader):
forward_dict = forward_pass(data, policy) # 前向传播
# 反向传播
loss = forward_dict['loss'] # 获取损失
loss.backward() # 反向传播
optimizer.step() # 更新参数
optimizer.zero_grad() # 清零梯度
train_history.append(detach_dict(forward_dict)) # 存储训练记录
epoch_summary = compute_dict_mean(train_history[(batch_idx + 1) * epoch:(batch_idx + 1) * (epoch + 1)]) # 计算训练平均记录
epoch_train_loss = epoch_summary['loss']
print(f'训练损失: {epoch_train_loss:.5f}')
summary_string = ''
for k, v in epoch_summary.items():
summary_string += f'{k}: {v.item():.3f} '
print(summary_string)
if epoch % 100 == 0: # 每 100 轮保存一次检查点
ckpt_path = os.path.join(ckpt_dir, f'policy_epoch_{epoch}_seed_{seed}.ckpt')
torch.save(policy.state_dict(), ckpt_path)
plot_history(train_history, validation_history, epoch, ckpt_dir, seed) # 绘制训练曲线
ckpt_path = os.path.join(ckpt_dir, f'policy_last.ckpt') # 保存最后的检查点
torch.save(policy.state_dict(), ckpt_path)
best_epoch, min_val_loss, best_state_dict = best_ckpt_info # 获取最佳检查点信息
ckpt_path = os.path.join(ckpt_dir, f'policy_epoch_{best_epoch}_seed_{seed}.ckpt') # 保存最佳检查点
torch.save(best_state_dict, ckpt_path)
print(f'训练完成:\n种子 {seed}, 验证损失 {min_val_loss:.6f} 在轮次 {best_epoch}')
# 保存训练曲线
plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed)
return best_ckpt_info
def plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed):
# 保存训练曲线
for key in train_history[0]:
plot_path = os.path.join(ckpt_dir, f'train_val_{key}_seed_{seed}.png') # 绘图路径
plt.figure()
train_values = [summary[key].item() for summary in train_history] # 训练值
val_values = [summary[key].item() for summary in validation_history] # 验证值
plt.plot(np.linspace(0, num_epochs - 1, len(train_history)), train_values, label='训练')
plt.plot(np.linspace(0, num_epochs - 1, len(validation_history)), val_values, label='验证')
plt.tight_layout()
plt.legend()
plt.title(key)
plt.savefig(plot_path) # 保存图像
print(f'图像已保存到 {ckpt_dir}')
if __name__ == '__main__':
parser = argparse.ArgumentParser() # 创建命令行参数解析器
parser.add_argument('--eval', action='store_true') # 评估标志
parser.add_argument('--onscreen_render', action='store_true') # 屏幕渲染标志
parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True) # 检查点目录
parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', required=True) # 策略类
parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True) # 任务名称
parser.add_argument('--batch_size', action='store', type=int, help='batch_size', required=True) # 批量大小
parser.add_argument('--seed', action='store', type=int, help='seed', required=True) # 随机种子
parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', required=True) # 训练轮数
parser.add_argument('--lr', action='store', type=float, help='lr', required=True) # 学习率
# 对于 ACT 策略
parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False) # KL 权重
parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False) # 块大小
parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False) # 隐藏维度
parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False) # 前馈维度
parser.add_argument('--temporal_agg', action='store_true') # 时间聚合标志
main(vars(parser.parse_args())) # 解析参数并调用主函数
程序结构和功能
该程序的主要功能是使用模仿学习的方法训练一个策略模型(如 ACT 或 CNN-MLP),以执行特定的任务(如模拟抓取和放置立方体)。程序包括以下主要部分:
-
导入模块:程序导入了必要的库和自定义模块,提供了深度学习、数据处理、可视化等功能。
-
主函数 (
main):该函数是程序的核心,负责处理命令行参数,加载任务配置,准备数据,训练模型或进行评估。 -
策略创建 (
make_policy) 和 优化器创建 (make_optimizer):这些函数根据指定的策略类创建相应的策略对象和优化器。 -
数据处理:通过
load_data函数加载训练和验证数据集,并进行预处理。 -
训练和评估:
train_bc函数用于执行训练过程,包括前向传播、损失计算、反向传播和优化器更新。eval_bc函数用于在评估模式下运行训练好的策略,计算成功率和平均奖励。
-
可视化:通过
save_videos和plot_history函数保存训练过程中的视频和训练曲线图。
调整参数以优化训练过程和模型性能
-
学习率 (
--lr):- 学习率是优化过程中最重要的超参数之一。较高的学习率可能导致训练不稳定,而较低的学习率可能导致收敛缓慢。建议在训练过程中使用学习率调度器,动态调整学习率,以适应训练过程。
-
批量大小 (
--batch_size):- 批量大小直接影响训练过程的稳定性和速度。较小的批量可以提高泛化能力,但训练时间较长;较大的批量可以加快训练速度,但可能导致过拟合。建议在实验中调整,找到合适的平衡。
-
训练轮数 (
--num_epochs):- 训练轮数决定了模型学习的程度。可以通过观察训练和验证损失的变化,判断模型是否过拟合,并根据需要调整轮数。
-
KL 权重 (
--kl_weight):- 在模仿学习中,KL 散度用于衡量生成策略与专家策略之间的距离。调整 KL 权重可以影响策略的学习效果,建议根据验证集的表现进行微调。
-
网络结构参数(如
--hidden_dim,--dim_feedforward):- 调整网络的复杂性(如隐藏层维度和前馈维度)可以提高模型的表达能力,但会增加计算成本和过拟合的风险。可以根据数据集的复杂性和大小进行实验。
-
时间聚合 (
--temporal_agg):- 该参数可以帮助模型更好地处理时间序列数据,通过调整该参数可以改善模型在执行动态任务时的表现。
-
随机种子 (
--seed):- 使用不同的随机种子进行实验,可以评估模型的稳定性和泛化能力。确保实验的可重复性也是重要的。
总结
整个程序是一个较为完整的模仿学习训练框架,涉及从数据加载、模型训练到评估和可视化的一系列环节。通过合理调整参数,可以有效地优化训练过程和最终模型的性能。建议在训练过程中监控训练和验证损失,并根据表现进行参数调整,以获得最佳结果。
更多推荐
所有评论(0)