train_test_split 中,stratify 参数的核心功能是保持数据划分后训练集和测试集的类别分布与原数据集一致。这在类别不平衡(Class Imbalance)的场景下尤为重要,例如手写数字识别(MNIST)中某些数字的样本可能较少(如数字 1 较多,数字 5 较少)。使用 stratify 可以确保划分后的子集保留原始数据的类别比例,避免模型因训练集或测试集分布偏差而表现异常。


1. stratify 的作用

  • 解决的问题
    当数据集中某些类别样本较少时,随机划分可能导致训练集或测试集中某些类别样本过少甚至缺失(例如测试集中缺少某个数字)。
  • 功能原理
    根据指定的标签列(y),按原数据中各类别的比例分层抽样,确保训练集和测试集的类别分布与原数据一致。

2. 使用方法

(1) 代码示例

以手写数字识别(MNIST)数据集为例:

import pandas as pd
from sklearn.model_selection import train_test_split

# 读取数据(假设标签列名为 'label')
df = pd.read_csv("mnist.csv")
X = df.iloc[:, 1:].values  # 特征(像素值)
y = df.iloc[:, 0].values   # 标签(数字0-9)

# 使用 stratify 参数分层划分
X_train, X_test, y_train, y_test = train_test_split(
    X, y, 
    test_size=0.2, 
    random_state=42,
    stratify=y  # 关键参数:按标签 y 的分布分层抽样
)
(2) 验证分布一致性

通过统计各类别比例,确认划分后的分布是否一致:

import numpy as np

# 统计原始数据、训练集、测试集的类别比例
original_dist = np.bincount(y) / len(y)
train_dist = np.bincount(y_train) / len(y_train)
test_dist = np.bincount(y_test) / len(y_test)

print("原始分布:", original_dist)
print("训练集分布:", train_dist)
print("测试集分布:", test_dist)

3. 适用场景

  • 多分类任务:如手写数字识别(0-9共10类)。
  • 二分类任务:如癌症检测(阳性/阴性样本不平衡)。
  • 小样本类别:某些类别样本极少(如罕见病诊断数据)。

4. 注意事项

(1) 必须基于标签列

stratify 参数需要传入标签 y,而非特征 X。算法会根据 y 的类别分布进行分层抽样。

(2) 类别样本数不足时的处理

如果某个类别的样本数过少(例如某类仅有1个样本),stratify 可能无法严格分层(因无法拆分为训练集和测试集),此时会抛出警告或错误。解决方法:

  • 过采样(Oversampling):使用 SMOTE 等方法增加少数类样本。
  • 调整分层策略:合并小类别或调整 test_size
(3) 与交叉验证的结合

在交叉验证中,使用 StratifiedKFold 替代常规 KFold,确保每折数据的分布一致:

from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for train_idx, test_idx in skf.split(X, y):  # 需传入 y
    X_train, X_test = X[train_idx], X[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]

5. 对比实验

(1) 不使用 stratify 的随机划分
  • 风险:测试集可能缺少某些类别,导致模型无法充分学习或评估。
    # 随机划分(可能分布不均)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    
(2) 使用 stratify 的分层划分
  • 优势:保证训练集和测试集的类别分布与原数据一致,提升模型泛化能力。

总结

  • 何时使用 stratify:当数据存在类别不平衡,且需要保持训练集/测试集分布一致时。
  • 如何传参:在 train_test_split 中设置 stratify=yy 为标签列)。
  • 避坑指南:处理极小样本类别时需谨慎,结合过采样或调整数据划分策略。
Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐