深度学习必备:使用 Optuna 自动化搜索最优模型参数
一、Optuna 简介
Optuna 是一个专为机器学习设计的自动超参数优化软件框架,它通过系统化的搜索策略,在给定的超参数空间内寻找最优配置,以提升模型性能。其核心原理基于高效的采样算法和智能的剪枝策略,旨在尽可能减少不必要的试验,快速定位到较优的超参数组合。
在搜索最优超参数的过程中,Optuna 主要通过以下步骤来实现:
1. 定义搜索空间
用户需要在目标函数中明确每个超参数的搜索范围和类型。对于连续型超参数,如学习率,通常使用 suggest_float 方法,若希望在对数尺度上进行搜索,可设置 log=True,例如 lr = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True),这意味着学习率将在 $10^{-5}$ 到 $10^{-3}$ 的对数空间内被采样。对于离散型超参数,若是在一个范围内取值,如神经网络隐藏层维度,可使用 suggest_int 方法,如 hidden_dim = trial.suggest_int("hidden_dim", 128, 512);若是枚举类型,像批量大小的选择 [16, 32, 64],则采用 suggest_categorical 方法,即 batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])。通过这种方式,Optuna 能够了解每个超参数可能的取值情况,构建起完整的超参数搜索空间。
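为便于对照,下面给出一个最小示例,把上述三种采样方法放进同一个目标函数(其中 train_and_evaluate 是假想的训练评估函数,仅作占位):

import optuna

def objective(trial):
    # 连续参数:学习率在对数尺度上采样
    lr = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
    # 整数范围参数:隐藏层维度
    hidden_dim = trial.suggest_int("hidden_dim", 128, 512)
    # 枚举参数:批量大小
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])
    # 用采样到的超参数训练并评估模型,返回验证集指标(train_and_evaluate 为假想函数)
    return train_and_evaluate(lr, hidden_dim, batch_size)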
2. 采样超参数组合
Optuna 基于特定的采样算法,从定义好的搜索空间中随机抽取超参数组合,生成一系列试验。每一次试验都代表着一个特定的超参数配置,Optuna 会运行目标函数(通常是模型训练和评估过程)来评估该配置下模型的性能,例如计算验证集损失。这些采样并非完全随机,而是利用算法尽量覆盖搜索空间的同时,倾向于探索更有潜力的区域,以提高找到最优解的效率。
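补充一点:Optuna 默认使用 TPE(Tree-structured Parzen Estimator)采样器,也可以在创建 study 时显式指定采样器并固定随机种子,以便结果可复现:

import optuna

sampler = optuna.samplers.TPESampler(seed=42)  # TPE 是 Optuna 的默认采样器
study = optuna.create_study(direction="minimize", sampler=sampler)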
3. 模型训练与评估
在每次试验中,Optuna 使用采样得到的超参数组合来初始化模型、优化器等组件,并进行模型训练。以常见的深度学习模型训练为例,模型在训练数据集上进行前向传播和反向传播,不断更新模型参数以最小化损失函数;同时在验证数据集上评估模型性能,计算验证集损失、准确率等指标。这些性能指标作为反馈,用于指导后续的超参数搜索。例如在目标函数中,通过 trial.report(avg_val_loss, epoch) 语句向 Optuna 报告每个 epoch 的验证集损失,以便 Optuna 根据这些中间结果做出决策。由于完整训练一个模型耗时过长,搜索阶段一般只训练 10~50 轮就在验证集上评估。这样虽然可能错过“慢热型”但最终效果更好的参数组合,却能极大地提升搜索效率,是一种折中方案。
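向 Optuna 报告中间结果的最简骨架如下(build_model、train_one_epoch、validate 均为假想的辅助函数,仅示意流程):

def objective(trial):
    model = build_model(trial)             # 假想:根据 trial 采样的超参数构建模型
    avg_val_loss = float("inf")
    for epoch in range(10):                # 示例:搜索阶段只训练少量轮次
        train_one_epoch(model)             # 假想:训练一个 epoch
        avg_val_loss = validate(model)     # 假想:返回验证集平均损失
        trial.report(avg_val_loss, epoch)  # 向 Optuna 报告该 epoch 的中间结果
    return avg_val_loss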
4. 剪枝策略
为了避免在明显不佳的超参数配置上浪费过多资源(如计算时间和计算资源),Optuna 引入了剪枝策略。以 MedianPruner 为例,它会对比中间轮次的试验结果,若某个试验在特定轮次的性能(如验证集损失)明显差于其他试验在相同轮次的中位数性能,那么 Optuna 会认为该试验不太可能产生最优超参数,从而提前终止该试验,即进行剪枝操作。例如在目标函数中,通过 if trial.should_prune(): raise optuna.exceptions.TrialPruned() 代码实现剪枝逻辑。这一策略大大减少了不必要的计算开销,使得 Optuna 能够将资源集中在更有希望的超参数组合上进行探索。
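MedianPruner 的行为可以在创建 study 时配置,下面是一个示例(参数取值仅供参考):

import optuna

pruner = optuna.pruners.MedianPruner(
    n_startup_trials=5,  # 前 5 个试验不剪枝,先积累对比基准
    n_warmup_steps=3,    # 每个试验的前 3 步(epoch)不剪枝
)
study = optuna.create_study(direction="minimize", pruner=pruner)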
5. 迭代搜索与结果分析
Optuna 会不断重复上述采样、训练评估和剪枝的过程,进行多轮试验。随着试验的进行,Optuna 会逐渐积累不同超参数组合及其对应的模型性能数据。通过分析这些数据,Optuna 能够逐渐了解哪些区域的超参数组合更有可能产生较好的模型性能,从而更加智能地进行后续的采样。当达到预设的试验次数或满足其他终止条件时,Optuna 停止搜索,并从所有试验中找出性能最优的超参数组合,将其作为最终推荐的最优超参数配置。同时,Optuna 还提供了丰富的可视化工具,如 optuna.visualization.plot_param_importances(study) 用于可视化参数重要性,帮助用户理解不同超参数对模型性能的影响程度,进一步优化超参数搜索策略。
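可视化 API 的典型用法如下(需要安装 plotly;这些函数返回 plotly 图对象,可在 notebook 中直接显示,也可保存为 HTML):

import optuna

# study 为已完成若干试验的研究对象
fig = optuna.visualization.plot_param_importances(study)     # 各超参数的重要性
fig.write_html("param_importances.html")
fig = optuna.visualization.plot_optimization_history(study)  # 目标值随试验次数的变化
fig.write_html("optimization_history.html")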
二、自动化超参数搜索的代码实现思路
1. 导入依赖并定义搜索常量(未安装 Optuna 的话先执行 pip install optuna)
import os
import json
import optuna
import torch
import numpy as np
from datetime import datetime
from torch.utils.data import DataLoader, random_split
# 定义搜索常量
NUM_EPOCHS = 5  # 每个试验的训练轮次(缩减以加速搜索)
NUM_TRIALS = 10  # 总试验次数
RESULT_DIR = "./optuna_results"
os.makedirs(RESULT_DIR, exist_ok=True)
2. 构建目标函数:Optuna 优化的核心逻辑
def objective(trial):
"""定义Optuna优化目标:最小化验证集损失"""
# 1. 超参数采样(定义搜索空间)
lr = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])
dropout = trial.suggest_float("dropout", 0.1, 0.5)
weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-2, log=True)
# 2. 设备与数据加载
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = YourDataset(CONFIG.DATA_PATH)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, num_workers=4
)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=4)
# 3. 模型与优化器初始化
model = YourModel(
input_dim=CONFIG.INPUT_DIM,
hidden_dim=trial.suggest_int("hidden_dim", 128, 512),
dropout=dropout
).to(device)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=lr,
weight_decay=weight_decay
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=NUM_EPOCHS
)
# 4. 训练与验证循环
best_val_loss = float("inf")
for epoch in range(NUM_EPOCHS):
# 训练阶段
model.train()
train_losses = []
for batch in train_loader:
# 前向传播与损失计算
outputs = model(batch["input"].to(device))
loss = torch.nn.functional.mse_loss(outputs, batch["target"].to(device))
# 反向传播与优化
optimizer.zero_grad()
loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
        scheduler.step()  # 每个 epoch 结束后更新余弦退火学习率
        # 验证阶段
model.eval()
val_losses = []
with torch.no_grad():
for batch in val_loader:
outputs = model(batch["input"].to(device))
loss = torch.nn.functional.mse_loss(outputs, batch["target"].to(device))
val_losses.append(loss.item())
avg_val_loss = np.mean(val_losses)
print(f"Epoch {epoch+1} - Val Loss: {avg_val_loss:.4f}")
# 5. 结果报告与模型保存
trial.report(avg_val_loss, epoch) # 向Optuna报告中间结果
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
torch.save(
{
"model_state_dict": model.state_dict(),
"hyperparameters": trial.params,
"val_loss": best_val_loss
},
f"{RESULT_DIR}/trial_{trial.number}.pth"
)
        # 6. 剪枝策略:若中间结果明显差于其他试验的中位数则提前终止
if trial.should_prune():
raise optuna.exceptions.TrialPruned()
return best_val_loss
3. 主函数:启动优化流程并处理结果
def main():
# 创建Optuna研究对象
timestamp = datetime.now().strftime("%Y%m%d_%H%M")
study_name = f"model_hyperparam_search_{timestamp}"
storage_path = f"sqlite:///{RESULT_DIR}/{study_name}.db"
study = optuna.create_study(
study_name=study_name,
storage=storage_path,
direction="minimize", # 目标:最小化损失
pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=2)
)
# 执行超参数优化
study.optimize(objective, n_trials=NUM_TRIALS)
# 输出最优结果
best_trial = study.best_trial
print(f"\n最优超参数: {best_trial.params}")
print(f"最优验证损失: {best_trial.value:.4f}")
# 保存最优模型到统一路径
best_model_path = f"{RESULT_DIR}/best_model.pth"
trial_model_path = f"{RESULT_DIR}/trial_{best_trial.number}.pth"
if os.path.exists(trial_model_path):
best_model_data = torch.load(trial_model_path)
torch.save(best_model_data, best_model_path)
print(f"最优模型已保存至: {best_model_path}")
# 保存超参数配置为JSON
with open(f"{RESULT_DIR}/best_params_{timestamp}.json", "w") as f:
params = {
**best_trial.params,
"best_loss": best_trial.value,
"timestamp": timestamp
}
json.dump(params, f, indent=2)
if __name__ == "__main__":
main()
三、进阶技巧
1. 超参数搜索空间设计
- 连续参数:suggest_float(对数尺度用 log=True)
- 离散参数:suggest_int(范围搜索)或 suggest_categorical(枚举搜索)
- 示例:学习率常用对数空间 1e-5 到 1e-3,批量大小用枚举 [16, 32, 64]
2. 剪枝策略优化
- MedianPruner:对比各试验在相同轮次的中间结果,明显差于中位数则剪枝
- HyperbandPruner:基于资源分配的自适应剪枝,适合长周期训练
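HyperbandPruner 的一个配置示例如下(资源以训练轮次计,参数取值仅供参考):

import optuna

pruner = optuna.pruners.HyperbandPruner(
    min_resource=1,      # 每个试验至少训练 1 个 epoch
    max_resource=50,     # 每个试验最多训练 50 个 epoch
    reduction_factor=3,  # 每一轮筛选约淘汰 2/3 的试验
)
study = optuna.create_study(direction="minimize", pruner=pruner)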
3. 分布式与持久化
- 通过 SQLite 存储研究状态:storage="sqlite:///study.db"
- 多节点共享优化:配置数据库存储(如 PostgreSQL)
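借助数据库存储,可以在中断后继续同一个 study,也可以让多个进程/节点并行搜索同一个 study。示例如下(study_name 为假设的名称,objective 为前文定义的目标函数):

import optuna

study = optuna.create_study(
    study_name="my_search",
    storage="sqlite:///study.db",  # 多节点共享时可换成 PostgreSQL 等数据库连接串
    load_if_exists=True,           # 同名 study 已存在时继续优化,而不是报错
    direction="minimize",
)
study.optimize(objective, n_trials=20)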
4. 可视化与分析
- 使用 Optuna 内置可视化:optuna.visualization.plot_param_importances(study)
- TensorBoard 记录训练过程:结合 SummaryWriter 跟踪损失曲线
以上只给出常用的进阶技巧的名称和作用,具体使用时可以上网查阅或者询问 AI
下面给出我在查找超参数时的完整示例代码:
import os
import math
import optuna
import torch
import torch.nn.functional as F
import numpy as np
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm
import json
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
# Import from local modules
import config
from dataset import StateDataset, load_or_create_stats
from model import ConditionalUnet1D
# Constants for hyperparameter search
NUM_EPOCHS_SEARCH = 10 # Reduced epochs for faster search
NUM_TRIALS = 20 # Number of parameter combinations to try
SEARCH_RESULT_DIR = "./hyperparameter_search_results"
BEST_MODEL_PATH = "./best_model.pth"
# Create results directory if it doesn't exist
os.makedirs(SEARCH_RESULT_DIR, exist_ok=True)
def objective(trial):
"""Optuna objective function to minimize validation loss."""
# Sample hyperparameters
prior_pi = trial.suggest_float("prior_pi", 0.1, 0.5)
log_sigma1 = trial.suggest_float("log_sigma1", -3.0, 0.0) # exp(-3) to exp(0)
log_sigma2 = trial.suggest_float("log_sigma2", -10.0, -5.0) # exp(-10) to exp(-5)
kl_beta = trial.suggest_float("kl_beta", 1e-7, 1e-3, log=True)
# Convert log values to actual sigma values
prior_sigma1 = math.exp(log_sigma1)
prior_sigma2 = math.exp(log_sigma2)
# Log the sampled parameters
print(f"\nTrial {trial.number}:")
print(f" prior_pi: {prior_pi:.4f}")
print(f" prior_sigma1: {prior_sigma1:.6f} (log: {log_sigma1:.2f})")
print(f" prior_sigma2: {prior_sigma2:.6f} (log: {log_sigma2:.2f})")
print(f" kl_beta: {kl_beta:.8f}")
# Setup device
device = torch.device(config.DEVICE)
# Load dataset
stats = load_or_create_stats(config.DATASET_PATH, config.STATS_PATH)
# Create dataset and split into train/val
full_dataset = StateDataset(
dataset_path=config.DATASET_PATH,
pred_horizon=config.PRED_HORIZON,
obs_horizon=config.OBS_HORIZON,
action_horizon=config.ACTION_HORIZON,
stats=stats
)
# Split dataset: 90% train, 10% validation
train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
full_dataset, [train_size, val_size]
)
# Create dataloaders
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=config.BATCH_SIZE,
num_workers=config.NUM_WORKERS,
shuffle=True,
pin_memory=True,
persistent_workers=True if config.NUM_WORKERS > 0 else False
)
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=config.BATCH_SIZE,
num_workers=config.NUM_WORKERS,
shuffle=False,
pin_memory=True,
persistent_workers=True if config.NUM_WORKERS > 0 else False
)
# Initialize model with the sampled hyperparameters
noise_pred_net = ConditionalUnet1D(
input_dim=config.MODEL_INPUT_DIM,
global_cond_dim=config.MODEL_GLOBAL_COND_DIM,
diffusion_step_embed_dim=config.DIFFUSION_STEP_EMBED_DIM,
down_dims=config.DOWNSAMPLE_DIMS,
kernel_size=config.MODEL_KERNEL_SIZE,
n_groups=config.MODEL_N_GROUPS,
prior_pi=prior_pi,
prior_sigma1=prior_sigma1,
prior_sigma2=prior_sigma2
).to(device)
# Setup noise scheduler
noise_scheduler = DDPMScheduler(
num_train_timesteps=config.NUM_DIFFUSION_TRAIN_STEPS,
beta_schedule=config.BETA_SCHEDULE,
clip_sample=True,
prediction_type='epsilon'
)
# Setup optimizer and scheduler
optimizer = torch.optim.AdamW(
params=noise_pred_net.parameters(),
lr=config.LEARNING_RATE,
weight_decay=config.WEIGHT_DECAY
)
lr_scheduler = get_scheduler(
name='cosine',
optimizer=optimizer,
num_warmup_steps=100, # Reduced for faster search
num_training_steps=len(train_loader) * NUM_EPOCHS_SEARCH
)
# Setup tensorboard
trial_name = f"trial_{trial.number}_pi{prior_pi:.2f}_s1{log_sigma1:.1f}_s2{log_sigma2:.1f}_kl{kl_beta:.1e}"
writer = SummaryWriter(f"{config.TB_LOG_DIR}/search/{trial_name}")
# Training loop
best_val_loss = float('inf')
for epoch in range(NUM_EPOCHS_SEARCH):
# Training phase
noise_pred_net.train()
train_losses = []
for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS_SEARCH} (Train)", leave=False):
nobs = batch['obs'].to(device)
naction = batch['action'].to(device)
B = nobs.shape[0]
obs_cond = nobs.flatten(start_dim=1)
noise = torch.randn(naction.shape, device=device)
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps,
(B,), device=device
).long()
noisy_actions = noise_scheduler.add_noise(naction, noise, timesteps)
noise_pred, kl_div = noise_pred_net(
noisy_actions, timesteps, global_cond=obs_cond
)
recon_loss = F.mse_loss(noise_pred, noise)
loss = recon_loss + kl_beta * kl_div / B
loss.backward()
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
train_losses.append(loss.item())
avg_train_loss = np.mean(train_losses)
writer.add_scalar('Loss/train', avg_train_loss, epoch)
# Validation phase
noise_pred_net.eval()
val_losses = []
val_recon_losses = []
val_kl_losses = []
with torch.no_grad():
for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS_SEARCH} (Val)", leave=False):
nobs = batch['obs'].to(device)
naction = batch['action'].to(device)
B = nobs.shape[0]
obs_cond = nobs.flatten(start_dim=1)
noise = torch.randn(naction.shape, device=device)
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps,
(B,), device=device
).long()
noisy_actions = noise_scheduler.add_noise(naction, noise, timesteps)
noise_pred, kl_div = noise_pred_net(
noisy_actions, timesteps, global_cond=obs_cond
)
recon_loss = F.mse_loss(noise_pred, noise)
kl_loss = kl_beta * kl_div / B
loss = recon_loss + kl_loss
val_losses.append(loss.item())
val_recon_losses.append(recon_loss.item())
val_kl_losses.append(kl_loss.item())
avg_val_loss = np.mean(val_losses)
avg_val_recon_loss = np.mean(val_recon_losses)
avg_val_kl_loss = np.mean(val_kl_losses)
writer.add_scalar('Loss/val', avg_val_loss, epoch)
writer.add_scalar('Loss/val_recon', avg_val_recon_loss, epoch)
writer.add_scalar('Loss/val_kl', avg_val_kl_loss, epoch)
print(f"Epoch {epoch+1}/{NUM_EPOCHS_SEARCH} - "
f"Train Loss: {avg_train_loss:.6f}, "
f"Val Loss: {avg_val_loss:.6f}, "
f"Val Recon: {avg_val_recon_loss:.6f}, "
f"Val KL: {avg_val_kl_loss:.6f}")
# Report intermediate value to Optuna
trial.report(avg_val_loss, epoch)
# Save best model
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
torch.save({
'model_state_dict': noise_pred_net.state_dict(),
'hyperparameters': {
'prior_pi': prior_pi,
'prior_sigma1': prior_sigma1,
'prior_sigma2': prior_sigma2,
'kl_beta': kl_beta
},
'val_loss': best_val_loss
}, f"{SEARCH_RESULT_DIR}/model_trial_{trial.number}.pth")
# Handle pruning based on the intermediate value
if trial.should_prune():
writer.close()
raise optuna.exceptions.TrialPruned()
writer.close()
return best_val_loss
def main():
# Create study
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
study_name = f"bayes_prior_search_{timestamp}"
storage_name = f"sqlite:///{SEARCH_RESULT_DIR}/{study_name}.db"
study = optuna.create_study(
study_name=study_name,
storage=storage_name,
direction="minimize",
pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=5)
)
# Run optimization
study.optimize(objective, n_trials=NUM_TRIALS)
# Print results
print("\nStudy statistics: ")
print(f" Number of finished trials: {len(study.trials)}")
print(f" Best trial:")
trial = study.best_trial
print(f" Value: {trial.value}")
print(f" Params: ")
for key, value in trial.params.items():
print(f" {key}: {value}")
# Calculate actual values from log values
best_prior_pi = trial.params["prior_pi"]
best_prior_sigma1 = math.exp(trial.params["log_sigma1"])
best_prior_sigma2 = math.exp(trial.params["log_sigma2"])
best_kl_beta = trial.params["kl_beta"]
print(f"\nBest hyperparameters:")
print(f" prior_pi: {best_prior_pi}")
print(f" prior_sigma1: {best_prior_sigma1} (log: {trial.params['log_sigma1']})")
print(f" prior_sigma2: {best_prior_sigma2} (log: {trial.params['log_sigma2']})")
print(f" kl_beta: {best_kl_beta}")
# Save best hyperparameters
best_params = {
"prior_pi": best_prior_pi,
"prior_sigma1": best_prior_sigma1,
"prior_sigma2": best_prior_sigma2,
"log_sigma1": trial.params["log_sigma1"],
"log_sigma2": trial.params["log_sigma2"],
"kl_beta": best_kl_beta,
"best_val_loss": trial.value,
"timestamp": timestamp
}
with open(f"{SEARCH_RESULT_DIR}/best_params_{timestamp}.json", "w") as f:
json.dump(best_params, f, indent=2)
# Load and save the best model
best_model_path = f"{SEARCH_RESULT_DIR}/model_trial_{trial.number}.pth"
if os.path.exists(best_model_path):
best_model_data = torch.load(best_model_path)
torch.save(best_model_data, BEST_MODEL_PATH)
print(f"Best model saved to {BEST_MODEL_PATH}")
if __name__ == "__main__":
main()