学习笔记:搞懂早停(Early Stopping)——模型训练的“刹车神器”

之前训练模型时,总想着“训练轮数越多,效果越好”,于是把Epochs设得很大(比如50轮、100轮),结果发现训练到后期,训练集的损失值越来越低、准确率越来越高,可验证集的损失值却开始飙升、准确率持续下降——这就是典型的过拟合,模型把训练集的“噪声”都当成了规律,泛化能力越来越差。直到后来用上了早停,才发现原来不用训满所有轮数,在合适的时机“踩刹车”,既能避免过拟合,又能节省大量训练时间,早停也成了我现在训练所有模型的“必备操作”。

一、通俗理解:早停到底是什么?

早停,顾名思义,就是“在模型训练还没完成预设轮数时,提前停止训练”——但它不是盲目停止,而是有明确的“刹车依据”,核心就是“盯着验证集的表现,一旦表现变差,就及时停手”。

我用备考刷题来打比方,特别容易理解:

  • 训练集就像“课后习题”,你反复刷题(训练模型),做题正确率(训练集准确率)越来越高,甚至能背下每道题的答案;
  • 验证集就像“模拟考试卷”,用来检验你真正的掌握程度(模型的泛化能力),而不是死记硬背的能力;
  • 早停就像“备考中的自我把控”:一开始你刷课后习题,模拟考试成绩也同步提升,可当你刷到第20套习题后,再参加模拟考试,成绩反而开始下降(因为你陷入了“死记硬背”,没真正理解知识点),这时候你就该停止刷更多习题,转而巩固已掌握的内容——早停做的就是这件事,当验证集表现不再提升甚至下滑时,立刻停止训练,保存此时(或验证集表现最优时)的模型。

再举我的实操案例:用BERT微调电商评论情感分析,预设Epochs=30,训练到第8轮时,验证集准确率达到峰值88.5%,第9轮开始,验证集准确率降到88.2%,第10轮降到87.8%,早停机制在第10轮后触发,直接停止训练,没有继续训完30轮——既避免了模型过拟合,又节省了20轮的训练时间(约1小时),最终保存的第8轮模型,泛化能力也是最好的。

二、早停的核心原理与目标(直击本质)

1. 核心目标:防止模型过拟合,保留最优泛化能力的模型

这是早停最核心的价值,没有之一。模型训练的过程,是一个“从欠拟合→拟合→过拟合”的演变过程:

  • 欠拟合阶段:训练集和验证集的表现都很差,模型还没学到数据的核心规律;
  • 拟合阶段:训练集表现稳步提升,验证集表现也同步提升,模型学到了数据的通用规律,这是最理想的状态;
  • 过拟合阶段:训练集表现持续提升(甚至接近100%),但验证集表现开始下滑,模型学到了训练集的专属噪声(比如个别样本的特殊表述、图片的无关像素),无法适配新数据。

早停的作用,就是“在模型进入过拟合阶段之前,或刚进入过拟合阶段时,及时刹车”,终止训练并保存验证集表现最优的模型,避免模型继续学习噪声,从而保留模型的最优泛化能力。

2. 核心原理:以“验证集指标”为监控依据,设置“耐心值”触发停止

早停的运行逻辑很简单,没有复杂的算法,核心就3个要素,一步一步拆解:

要素1:选择合适的“监控指标”

必须选验证集上的任务相关指标,而不是训练集指标,常见指标选择:

  • 分类任务(单标签/多标签):优先选准确率、F1值(尤其是类别不平衡时),也可以监控验证集损失值;
  • 回归任务(比如房价预测、销量预测):优先选MSE(均方误差)、MAE(平均绝对误差),或验证集损失值;
  • 我的实操经验:分类任务优先监控F1值,比准确率更稳定,不容易受个别样本影响;比如商品多标签分类,用F1值作为监控指标,早停触发更精准,而用准确率容易出现“指标波动大”的问题,导致早停时机判断失误。
要素2:设置“耐心值”(Patience,也叫容忍度)

耐心值就是“允许验证集指标连续多少轮不提升(甚至下滑),才触发早停”,比如Patience=5,意味着:

  • 当验证集指标达到一个峰值后,接下来连续5轮训练,指标都没有超过这个峰值(甚至持续下滑),就触发早停;
  • 耐心值的作用是“过滤指标的正常波动”,避免因为一两轮的指标小幅下滑,就误触发早停(毕竟模型训练过程中,验证集指标有小幅波动是正常的);
  • 我的实操经验:小白入门优先设Patience=510,任务复杂(比如大模型文本生成)可以设到1015,不要太小(比如Patience=2),容易导致模型还没训练充分就提前停止(欠拟合),也不要太大(比如Patience=20),会浪费训练时间,也可能让模型陷入过拟合。
要素3:保存“最优模型”(Best Model)

早停不是“停在哪轮就保存哪轮模型”,而是在训练过程中,持续监控验证集指标,一旦出现更优的指标(比如准确率更高、损失值更低),就自动覆盖保存当前模型,最终早停触发时,我们手里保存的是“整个训练过程中,验证集表现最好的那轮模型”,而不是最后一轮表现下滑的模型。

  • 我的实操感悟:这是早停最容易被忽略的关键细节,我第一次用早停时,只设置了停止触发,没保存最优模型,结果停在第10轮,保存的是第10轮的模型(准确率87.8%),而不是第8轮的最优模型(准确率88.5%),白白浪费了最优结果,后来每次用早停,都会加上“最优模型保存”的步骤,再也没踩过这个坑。

三、早停的实操流程(以PyTorch为例,小白可直接套用)

早停在实操中没有内置的“一键调用”(PyTorch需自定义简单类,TensorFlow有内置EarlyStopping),但逻辑简单,我整理了自己常用的实操流程,小白直接复制修改就能用:

步骤1:自定义早停类(核心代码,简单易懂)

import torch
import os

class EarlyStopping:
    def __init__(self, patience=5, verbose=False, delta=0, save_path='best_model.pth'):
        # 耐心值:允许连续多少轮不提升
        self.patience = patience
        # 是否打印日志
        self.verbose = verbose
        # 指标提升的最小阈值(避免微小波动被判定为提升)
        self.delta = delta
        # 最优模型保存路径
        self.save_path = save_path
        
        # 初始化内部变量
        self.counter = 0  # 连续不提升轮数计数器
        self.best_score = None  # 最优指标值
        self.early_stop = False  # 是否触发早停
        self.val_loss_min = float('inf')  # 初始验证集损失设为无穷大

    def __call__(self, val_metric, model):
        # 这里val_metric可以是损失值(越小越好),也可以是准确率/F1值(越大越好)
        # 我这里以“验证集损失值”为例(越小越好)
        score = -val_metric  # 损失值越小,score越大,方便统一判断

        if self.best_score is None:
            # 第一轮训练,直接保存模型
            self.best_score = score
            self.save_checkpoint(val_metric, model)
        elif score < self.best_score + self.delta:
            # 验证集指标没有提升,计数器+1
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                # 计数器达到耐心值,触发早停
                self.early_stop = True
        else:
            # 验证集指标提升,更新最优分数,保存模型,重置计数器
            self.best_score = score
            self.save_checkpoint(val_metric, model)
            self.counter = 0

    def save_checkpoint(self, val_metric, model):
        # 保存最优模型
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_metric:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.save_path)
        self.val_loss_min = val_metric

步骤2:训练过程中调用早停

# 1. 初始化早停(耐心值设为5,开启日志,保存最优模型)
early_stopping = EarlyStopping(patience=5, verbose=True, save_path='bert_sentiment_best.pth')

# 2. 开始训练循环
epochs = 30
for epoch in range(epochs):
    # 训练步骤(省略,正常训练模型,得到train_loss/train_acc)
    model.train()
    # ... 训练代码 ...
    
    # 验证步骤(关键:得到验证集指标val_loss/val_acc)
    model.eval()
    val_loss = 0.0
    # ... 验证代码,计算val_loss ...
    
    # 3. 调用早停,传入验证集损失值和模型
    early_stopping(val_loss, model)
    
    # 4. 判断是否触发早停,若是则终止训练
    if early_stopping.early_stop:
        print("Early stopping triggered!")
        break

# 5. 训练结束后,加载最优模型
model.load_state_dict(torch.load('bert_sentiment_best.pth'))

步骤3:关键参数调整(小白必看)

  • patience:根据任务复杂度调整,515为宜,简单任务(比如简单图像分类)设5,复杂任务(比如大模型微调)设1015;
  • delta:指标提升的最小阈值,比如设0.0001,避免“验证集损失从0.1000降到0.0999”这种微小波动被判定为“指标提升”,减少模型频繁保存的开销;
  • save_path:设置清晰的模型保存路径,最好带任务名称和时间戳,比如bert_sentiment_best_20251228.pth,方便后续查找。

四、早停的关键注意事项与避坑总结(亲测踩过的6个坑)

  1. 必须用“独立的验证集”,不能用训练集或测试集:验证集必须和训练集无交集,且数据分布一致,否则监控的指标没有参考意义,早停时机判断会完全失误;我第一次做多标签分类时,偷懒用训练集的子集做验证集,结果早停触发时,模型在测试集上的表现一塌糊涂,换成独立验证集后,问题立刻解决。
  2. 监控指标要“贴合任务需求”,不能只盯损失值:分类任务中,验证集损失值可能小幅波动,但F1值持续提升,这时候不能触发早停;反之,损失值下降,但F1值下滑,说明模型过拟合,应该触发早停;我的经验是“分类任务优先盯F1值/准确率,回归任务优先盯MSE/MAE,损失值作为辅助参考”。
  3. 耐心值设置要“适中”,避免欠拟合或过拟合:
    • 耐心值太小(比如Patience=2):容易被指标的正常波动误导,模型还没训练充分就提前停止,导致欠拟合;
    • 耐心值太大(比如Patience=20):会浪费大量训练时间,也会让模型进入深度过拟合阶段,即使保存了最优模型,泛化能力也会受影响;
    • 我的实操参考:小白入门先设Patience=5,跑通后根据指标波动情况调整,若指标波动大,适当调大到8~10。
  4. 不要和“批量归一化(BN)”“学习率衰减”冲突:早停和BN、学习率衰减是可以搭配使用的(都是防过拟合的方法),但要注意顺序:先做学习率衰减,再进行验证,最后调用早停;我之前把早停放在学习率衰减之前,导致早停触发时,学习率还没调整,模型表现不是最优的,调整顺序后,模型效果提升了2%。
  5. 保存模型时要“完整”,不能漏存关键参数:如果模型有自定义层或优化器参数,要注意保存完整的模型状态,而不仅仅是模型权重;比如用torch.save(model.state_dict(), ...)只保存权重,若需要保存优化器状态,可用torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, ...)
  6. 早停是“辅助防过拟合”,不是“万能解药”:早停只能防止“训练轮数过多导致的过拟合”,无法解决“数据量不足、数据分布不均、模型参数量过大”等问题;比如用BERT-large训练500条样本,即使用上早停,模型还是会过拟合,这时候需要先扩充数据量或用轻量模型(比如DistilBERT),再搭配早停,才能达到理想效果。

五、早停与其他防过拟合方法的区别(清晰对比)

为了避免混淆,我把早停和常见的防过拟合方法做了对比,结合自己的实操体验,一目了然:

防过拟合方法 核心逻辑 与早停的区别 实操搭配建议
早停(Early Stopping) 终止训练,保留验证集最优模型,避免学习噪声 不修改模型结构和参数,仅终止训练过程 可与任何防过拟合方法搭配,是“基础必备”
Dropout 训练时随机丢弃部分神经元,防止模型依赖特定神经元 修改模型结构,主动减少模型的学习能力 早停+Dropout,效果1+1>2,是大模型微调的黄金组合
权重衰减(L2正则) 给模型权重添加惩罚项,防止权重过大,避免过拟合 修改损失函数,主动约束模型权重 早停+权重衰减,适合回归任务和类别不平衡任务
数据增强 扩充训练数据量,增加数据多样性,让模型学到更通用的规律 从数据层面解决过拟合,不涉及模型和训练过程 早停+数据增强,是解决“数据量不足”导致过拟合的最优方案

六、核心感悟

其实早停是一个“简单却极其有效”的训练技巧,它没有复杂的原理,也不用修改模型结构,仅仅通过“监控验证集指标、及时终止训练”,就能有效防止过拟合,节省训练时间。我当初觉得它“不起眼”,直到踩了“训练轮数过多导致过拟合”的大坑,才意识到它的重要性——现在我训练任何模型,都会默认加上早停,它就像一个“智能刹车”,帮我守住模型泛化能力的底线。

对于小白来说,不用一开始就纠结早停的实现细节,重点是先理解它的核心逻辑(“盯验证集、设耐心、存最优”),然后套用现成的早停代码,先跑通一个简单任务(比如MNIST手写数字识别、简单文本分类),再根据自己的任务调整参数,慢慢就能掌握它的精髓。

早停的核心价值,不是“让模型训练得更快”,而是“让模型保留最好的泛化能力”——在AI模型训练中,“不是训得越久越好,而是训得恰到好处最好”,这也是早停给我最深刻的启示。现在再训练模型,我再也不会盲目追求“训满所有轮数”,而是学会了“见好就收”,这也让我的模型项目成功率提升了不少

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐