【机器学习案例-38】Kaggle案例之小麦品种分类:从数据探索到模型部署的完整流程
数据集包含三种不同小麦品种(Kama、Rosa、Canadian)的物理属性数据,每种70个样本。每个小麦粒都测量了7个几何参数:4.2 特征工程策略交互特征:捕捉特征间的相互关系多项式特征:增强模型的非线性拟合能力特征选择:基于统计显著性筛选重要特征4.3 模型评估方法混淆矩阵:直观展示分类结果ROC曲线:评估各类别的分类性能交叉验证:减少过拟合风险多指标评估:全面评价模型性能五、运行结果与性能
🧑 博主简介:曾任某智慧城市类企业
算法总监,目前在美国市场的物流公司从事高级算法工程师一职,深耕人工智能领域,精通python数据挖掘、可视化、机器学习等,发表过AI相关的专利并多次在AI类比赛中获奖。CSDN人工智能领域的优质创作者,提供AI相关的技术咨询、项目开发和个性化解决方案等服务,如有需要请站内私信或者联系任意文章底部的的VX名片(ID:xf982831907)
💬 博主粉丝群介绍:① 群内初中生、高中生、本科生、研究生、博士生遍布,可互相学习,交流困惑。② 热榜top10的常客也在群里,也有数不清的万粉大佬,可以交流写作技巧,上榜经验,涨粉秘籍。③ 群内也有职场精英,大厂大佬,可交流技术、面试、找工作的经验。④ 进群免费赠送写作秘籍一份,助你由写作小白晋升为创作大佬。⑤ 进群赠送CSDN评论防封脚本,送真活跃粉丝,助你提升文章热度。有兴趣的加文末联系方式,备注自己的CSDN昵称,拉你进群,互相学习共同进步。

【机器学习案例-38】Kaggle案例之小麦品种分类:从数据探索到模型部署的完整流程
本文完整实现了一个小麦品种分类的机器学习项目,基于物理属性特征(面积、周长、长度等)对三种小麦品种(Kama、Rosa、Canadian)进行分类。文章包含完整的数据探索、特征工程、逻辑回归模型训练、性能评估和模型部署的全流程代码。
一、项目背景与目标
1.1 数据集介绍
数据集包含三种不同小麦品种(Kama、Rosa、Canadian)的物理属性数据,每种70个样本。每个小麦粒都测量了7个几何参数:
- 面积 (area): 小麦粒的横截面积
- 周长 (perimeter): 小麦粒的周长
- 紧凑度 (compactness): 4π × 面积 / 周长²
- 长度 (length): 小麦粒的长度
- 宽度 (width): 小麦粒的宽度
- 不对称系数 (asymmetry coefficient): 形状的不对称性
- 沟长 (groove length): 小麦粒沟的长度
1.2 项目目标
- 探索小麦品种数据的分布特征和类别差异
- 构建有效的特征工程流程
- 训练逻辑回归分类器识别小麦品种
- 评估模型性能并进行模型部署
二、完整实现代码
2.1 环境准备与数据加载
# ==============================================
# 1. 导入必要的库
# ==============================================
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import warnings
warnings.filterwarnings('ignore')
# 机器学习库
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
f1_score, confusion_matrix, classification_report,
roc_auc_score, roc_curve)
from sklearn.feature_selection import SelectKBest, f_classif, RFE
# 设置中文字体和图表样式
plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
sns.set_style("whitegrid")
# 设置随机种子
np.random.seed(42)
# ==============================================
# 2. 数据加载与初步探索
# ==============================================
print("="*60)
print("小麦品种分类数据集加载")
print("="*60)
# 加载数据集
# 注意:请确保wheat.csv文件在当前目录或正确路径下
df = pd.read_csv("wheat.csv")
print(f"数据集形状: {df.shape}")
print(f"行数: {df.shape[0]}, 列数: {df.shape[1]}")
# 显示数据基本信息
print("\n数据集信息:")
print(df.info())
print("\n前5行数据:")
print(df.head())
print("\n数据描述性统计:")
print(df.describe().round(4))
# 检查数据质量
print("\n数据质量检查:")
print(f"缺失值数量:\n{df.isnull().sum()}")
print(f"\n重复值数量: {df.duplicated().sum()}")
print(f"\n数据类型:\n{df.dtypes}")

2.2 数据可视化分析
# ==============================================
# 3. 数据可视化分析
# ==============================================
%matplotlib inline
def perform_eda_analysis(df):
"""执行探索性数据分析"""
print("="*60)
print("探索性数据分析(EDA)")
print("="*60)
# 创建综合可视化图表
fig = plt.figure(figsize=(20, 15))
# 1. 特征分布直方图
ax1 = plt.subplot(3, 3, 1)
features = ['area', 'perimeter', 'compactness', 'length',
'width', 'asymmetry coefficient', 'groove length']
for i, feature in enumerate(features):
sns.kdeplot(df[feature], label=feature, ax=ax1)
ax1.set_title('特征分布密度图', fontsize=14, fontweight='bold')
ax1.set_xlabel('特征值', fontsize=12)
ax1.set_ylabel('密度', fontsize=12)
ax1.legend(fontsize=9)
ax1.grid(alpha=0.3)
# 2. 类别分布饼图
ax2 = plt.subplot(3, 3, 2)
category_counts = df['category'].value_counts()
category_names = {1.0: 'Kama', 2.0: 'Rosa', 3.0: 'Canadian'}
labels = [category_names.get(cat, f'类别{cat}') for cat in category_counts.index]
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']
wedges, texts, autotexts = ax2.pie(category_counts.values, labels=labels,
autopct='%1.1f%%', colors=colors,
startangle=90)
ax2.set_title('小麦品种分布', fontsize=14, fontweight='bold')
plt.setp(autotexts, size=10, weight="bold")
# 3. 特征箱线图
ax3 = plt.subplot(3, 3, 3)
df_melted = pd.melt(df[features], var_name='特征', value_name='值')
sns.boxplot(x='特征', y='值', data=df_melted, ax=ax3)
ax3.set_title('特征箱线图(检测异常值)', fontsize=14, fontweight='bold')
ax3.set_xlabel('特征', fontsize=12)
ax3.set_ylabel('特征值', fontsize=12)
ax3.tick_params(axis='x', rotation=45)
ax3.grid(alpha=0.3)
# 4. 特征相关性热图
ax4 = plt.subplot(3, 3, 4)
correlation_matrix = df[features + ['category']].corr()
sns.heatmap(correlation_matrix, annot=True, fmt='.2f', cmap='coolwarm',
center=0, ax=ax4, square=True, cbar_kws={"shrink": 0.8})
ax4.set_title('特征相关性热图', fontsize=14, fontweight='bold')
plt.tight_layout()
# 5. 类别特征分布(小提琴图)
ax5 = plt.subplot(3, 3, 5)
sample_features = ['area', 'length', 'width']
df_melted2 = pd.melt(df[['category'] + sample_features],
id_vars=['category'],
value_vars=sample_features,
var_name='特征', value_name='值')
sns.violinplot(x='特征', y='值', hue='category', data=df_melted2,
split=True, inner="quart", palette=colors, ax=ax5)
ax5.set_title('不同品种特征分布(小提琴图)', fontsize=14, fontweight='bold')
ax5.set_xlabel('特征', fontsize=12)
ax5.set_ylabel('特征值', fontsize=12)
ax5.legend(title='类别', loc='upper right')
ax5.grid(alpha=0.3)
# 6. 散点矩阵图(前4个特征)
ax6 = plt.subplot(3, 3, 6)
from pandas.plotting import scatter_matrix
scatter_features = features[:4]
scatter_df = df[scatter_features + ['category']]
# 7. Q-Q图(正态性检验)
ax7 = plt.subplot(3, 3, 7)
stats.probplot(df['area'], dist="norm", plot=ax7)
ax7.set_title('面积特征的Q-Q图(正态性检验)', fontsize=14, fontweight='bold')
ax7.set_xlabel('理论分位数', fontsize=12)
ax7.set_ylabel('样本分位数', fontsize=12)
ax7.grid(alpha=0.3)
# 8. 累计分布函数
ax8 = plt.subplot(3, 3, 8)
for feature in ['area', 'perimeter', 'length']:
sorted_data = np.sort(df[feature])
y_vals = np.arange(len(sorted_data)) / float(len(sorted_data))
ax8.plot(sorted_data, y_vals, label=feature, linewidth=2)
ax8.set_title('累计分布函数(CDF)', fontsize=14, fontweight='bold')
ax8.set_xlabel('特征值', fontsize=12)
ax8.set_ylabel('累计概率', fontsize=12)
ax8.legend(fontsize=10)
ax8.grid(alpha=0.3)
# 9. 统计摘要表
ax9 = plt.subplot(3, 3, 9)
ax9.axis('tight')
ax9.axis('off')
# 计算基本统计量
stats_summary = pd.DataFrame({
'统计量': ['样本总数', '特征数量', '类别数量', 'Kama样本数',
'Rosa样本数', 'Canadian样本数', '数据完整性'],
'值': [len(df), len(features), 3,
len(df[df['category'] == 1.0]),
len(df[df['category'] == 2.0]),
len(df[df['category'] == 3.0]),
'100% (无缺失值)']
})
table = ax9.table(cellText=stats_summary.values,
colLabels=stats_summary.columns,
cellLoc='center',
loc='center',
colWidths=[0.25, 0.25])
table.auto_set_font_size(False)
table.set_fontsize(9)
table.scale(1, 2)
for i in range(len(stats_summary.columns)):
table[(0, i)].set_facecolor('#40466e')
table[(0, i)].set_text_props(weight='bold', color='white')
ax9.set_title('数据集统计摘要', fontsize=14, fontweight='bold', y=0.98)
plt.suptitle('小麦品种数据集探索性分析', fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()
# 打印详细统计信息
print("\n详细统计信息:")
print("="*50)
# 按类别统计
print("\n按类别统计:")
for cat, name in category_names.items():
cat_df = df[df['category'] == cat]
print(f"\n{name} (类别{cat}):")
print(f" 样本数: {len(cat_df)}")
print(f" 特征均值:")
for feature in features:
print(f" {feature}: {cat_df[feature].mean():.4f}")
# 简化版的散点图矩阵
sns.pairplot(data=scatter_df, hue='category', palette=colors,
diag_kind='kde', plot_kws={'alpha': 0.6})
plt.suptitle('特征散点矩阵图(前4个特征)', fontsize=14, y=1.02)
return df
# 执行EDA分析
df = perform_eda_analysis(df)



2.3 特征工程
# ==============================================
# 4. 特征工程
# ==============================================
class WheatFeatureEngineer:
"""小麦特征工程类"""
def __init__(self, df):
self.df = df.copy()
self.features = None
def create_basic_features(self):
"""创建基础特征"""
print("创建基础特征...")
# 分离特征和目标变量
features = ['area', 'perimeter', 'compactness', 'length',
'width', 'asymmetry coefficient', 'groove length']
X = self.df[features].copy()
y = self.df['category'].copy()
print(f"原始特征数量: {len(features)}")
print(f"特征列表: {features}")
return X, y
def create_interaction_features(self, X):
"""创建交互特征"""
print("\n创建交互特征...")
# 面积相关交互特征
X['area_perimeter_ratio'] = X['area'] / X['perimeter']
X['area_length_ratio'] = X['area'] / X['length']
X['area_width_ratio'] = X['area'] / X['width']
# 形状特征
X['length_width_ratio'] = X['length'] / X['width']
X['compactness_area'] = X['compactness'] * X['area']
# 综合特征
X['shape_index'] = (X['length'] * X['width']) / X['area']
X['volume_approx'] = X['area'] * X['groove length'] / 2
print(f"交互特征数量: {6}")
print(f"总特征数量: {X.shape[1]}")
return X
def create_polynomial_features(self, X):
"""创建多项式特征(选择性的)"""
print("\n创建多项式特征...")
# 选择重要特征创建二次项
important_features = ['area', 'length', 'width', 'groove length']
for feature in important_features:
X[f'{feature}_squared'] = X[feature] ** 2
X[f'{feature}_sqrt'] = np.sqrt(X[feature])
print(f"多项式特征数量: {len(important_features) * 2}")
print(f"总特征数量: {X.shape[1]}")
return X
def statistical_analysis(self, X):
"""特征统计分析"""
print("\n特征统计分析:")
print("="*50)
# 计算每个特征的统计量
stats_df = pd.DataFrame({
'均值': X.mean(),
'标准差': X.std(),
'最小值': X.min(),
'25%分位数': X.quantile(0.25),
'中位数': X.median(),
'75%分位数': X.quantile(0.75),
'最大值': X.max(),
'偏度': X.skew(),
'峰度': X.kurtosis()
})
print("\n特征统计摘要:")
print(stats_df.round(4))
# 检测异常值
print("\n异常值检测:")
for col in X.columns:
Q1 = X[col].quantile(0.25)
Q3 = X[col].quantile(0.75)
IQR = Q3 - Q1
lower_bound = Q1 - 1.5 * IQR
upper_bound = Q3 + 1.5 * IQR
outliers = X[(X[col] < lower_bound) | (X[col] > upper_bound)]
if len(outliers) > 0:
print(f" {col}: {len(outliers)}个异常值 ({len(outliers)/len(X)*100:.1f}%)")
return X
def feature_selection(self, X, y, k=10):
"""特征选择"""
print("\n特征选择...")
# 使用ANOVA F-value进行特征选择
selector = SelectKBest(score_func=f_classif, k=min(k, X.shape[1]))
X_selected = selector.fit_transform(X, y)
# 获取特征得分
feature_scores = pd.DataFrame({
'特征': X.columns,
'F值': selector.scores_,
'P值': selector.pvalues_
}).sort_values('F值', ascending=False)
print("\n特征重要性排序(基于ANOVA F值):")
print(feature_scores.round(4))
selected_features = feature_scores['特征'].head(k).tolist()
X_selected_df = X[selected_features].copy()
print(f"\n选择的特征 ({k}个):")
print(selected_features)
return X_selected_df, selected_features
def visualize_features(self, X, y):
"""特征可视化"""
print("\n特征可视化分析...")
# 创建可视化图表
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
# 1. 特征重要性条形图
ax1 = axes[0, 0]
# 计算特征与目标的相关性
correlations = []
for col in X.columns:
if len(np.unique(X[col])) > 1:
f_val, p_val = f_classif(X[[col]], y)
correlations.append(abs(f_val[0]))
else:
correlations.append(0)
feature_importance = pd.DataFrame({
'特征': X.columns,
'重要性': correlations
}).sort_values('重要性', ascending=True)
bars = ax1.barh(range(len(feature_importance)), feature_importance['重要性'].values,
color=plt.cm.viridis(np.linspace(0.3, 0.9, len(feature_importance))))
ax1.set_yticks(range(len(feature_importance)))
ax1.set_yticklabels(feature_importance['特征'])
ax1.set_title('特征重要性排序(F值)', fontsize=14, fontweight='bold')
ax1.set_xlabel('F值(绝对值)', fontsize=12)
ax1.invert_yaxis()
ax1.grid(axis='x', alpha=0.3)
# 2. 特征相关性热图
ax2 = axes[0, 1]
correlation_matrix = X.corr()
sns.heatmap(correlation_matrix, annot=True, fmt='.2f', cmap='coolwarm',
center=0, ax=ax2, square=True, cbar_kws={"shrink": 0.8})
ax2.set_title('特征相关性热图', fontsize=14, fontweight='bold')
# 3. PCA降维可视化
ax3 = axes[1, 0]
from sklearn.decomposition import PCA
# 标准化数据
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# PCA降维到2维
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X_scaled)
# 创建颜色映射
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']
color_map = {1.0: colors[0], 2.0: colors[1], 3.0: colors[2]}
for category in np.unique(y):
mask = y == category
ax3.scatter(X_pca[mask, 0], X_pca[mask, 1],
color=color_map[category], alpha=0.6,
label=f'类别{int(category)}', s=50)
ax3.set_xlabel(f'主成分1 (解释方差: {pca.explained_variance_ratio_[0]:.2%})', fontsize=12)
ax3.set_ylabel(f'主成分2 (解释方差: {pca.explained_variance_ratio_[1]:.2%})', fontsize=12)
ax3.set_title('PCA降维可视化', fontsize=14, fontweight='bold')
ax3.legend(fontsize=10)
ax3.grid(alpha=0.3)
# 4. 特征分布箱线图(按类别)
ax4 = axes[1, 1]
# 选择最重要的4个特征
top_features = feature_importance['特征'].tail(4).tolist()
data_to_plot = []
for feature in top_features:
for category in np.unique(y):
values = X[y == category][feature].values
data_to_plot.append({
'特征': feature,
'类别': int(category),
'值': values
})
# 创建DataFrame用于绘图
plot_data = pd.DataFrame(data_to_plot)
plot_data = plot_data.explode('值')
sns.boxplot(x='特征', y='值', hue='类别', data=plot_data,
palette=colors, ax=ax4)
ax4.set_title('重要特征分布(按类别)', fontsize=14, fontweight='bold')
ax4.set_xlabel('特征', fontsize=12)
ax4.set_ylabel('特征值', fontsize=12)
ax4.legend(title='小麦品种', fontsize=9)
ax4.grid(alpha=0.3)
plt.suptitle('特征工程分析结果', fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()
# 打印PCA解释方差
print(f"\nPCA分析结果:")
print(f"主成分1解释方差: {pca.explained_variance_ratio_[0]:.4f}")
print(f"主成分2解释方差: {pca.explained_variance_ratio_[1]:.4f}")
print(f"累计解释方差: {pca.explained_variance_ratio_[:2].sum():.4f}")
return X
def prepare_data_for_modeling(self):
"""准备建模数据"""
print("\n" + "="*60)
print("准备建模数据")
print("="*60)
# 创建基础特征
X, y = self.create_basic_features()
# 创建交互特征
X = self.create_interaction_features(X)
# 创建多项式特征
X = self.create_polynomial_features(X)
# 特征统计分析
X = self.statistical_analysis(X)
# 特征选择(选择最重要的10个特征)
X_selected, selected_features = self.feature_selection(X, y, k=10)
# 特征可视化
self.visualize_features(X_selected, y)
# 数据标准化
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_selected)
self.features = X_scaled
self.target = y.values
self.feature_names = selected_features
print(f"\n最终数据形状:")
print(f"特征矩阵: {X_scaled.shape}")
print(f"目标变量: {y.shape}")
return X_scaled, y.values, selected_features
# 执行特征工程
feature_engineer = WheatFeatureEngineer(df)
X, y, feature_names = feature_engineer.prepare_data_for_modeling()



2.4 逻辑回归模型训练与评估
# ==============================================
# 5. 逻辑回归模型训练与评估
# ==============================================
class WheatLogisticRegressionModel:
"""小麦逻辑回归模型类"""
def __init__(self):
self.model = None
self.scaler = StandardScaler()
self.best_params = None
def train_model(self, X, y, test_size=0.2):
"""训练逻辑回归模型"""
print("="*60)
print("逻辑回归模型训练")
print("="*60)
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=test_size, random_state=42, stratify=y
)
print(f"数据集划分:")
print(f"训练集: X={X_train.shape}, y={y_train.shape}")
print(f"测试集: X={X_test.shape}, y={y_test.shape}")
# 数据标准化
X_train_scaled = self.scaler.fit_transform(X_train)
X_test_scaled = self.scaler.transform(X_test)
# 超参数调优
print("\n超参数调优...")
best_score = 0
best_params = {}
# 简单的网格搜索
param_grid = {
'C': [0.01, 0.1, 1, 10, 100],
'solver': ['lbfgs', 'newton-cg', 'sag'],
'max_iter': [100, 200, 500]
}
for C in param_grid['C']:
for solver in param_grid['solver']:
for max_iter in param_grid['max_iter']:
try:
model = LogisticRegression(
C=C,
solver=solver,
max_iter=max_iter,
multi_class='multinomial',
random_state=42
)
# 使用交叉验证评估
cv_scores = cross_val_score(model, X_train_scaled, y_train,
cv=5, scoring='accuracy')
mean_score = cv_scores.mean()
if mean_score > best_score:
best_score = mean_score
best_params = {
'C': C,
'solver': solver,
'max_iter': max_iter
}
print(f" C={C}, solver={solver}, max_iter={max_iter}: "
f"CV准确率={mean_score:.4f}")
except:
continue
print(f"\n最佳参数: {best_params}")
print(f"最佳交叉验证准确率: {best_score:.4f}")
# 使用最佳参数训练模型
print("\n使用最佳参数训练模型...")
self.model = LogisticRegression(
C=best_params['C'],
solver=best_params['solver'],
max_iter=best_params['max_iter'],
multi_class='multinomial',
random_state=42
)
self.model.fit(X_train_scaled, y_train)
self.best_params = best_params
# 在测试集上评估
y_pred = self.model.predict(X_test_scaled)
y_pred_proba = self.model.predict_proba(X_test_scaled)
# 计算评估指标
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred, average='weighted')
recall = recall_score(y_test, y_pred, average='weighted')
f1 = f1_score(y_test, y_pred, average='weighted')
print(f"\n测试集评估结果:")
print(f"准确率: {accuracy:.4f}")
print(f"精确率: {precision:.4f}")
print(f"召回率: {recall:.4f}")
print(f"F1分数: {f1:.4f}")
return X_test_scaled, y_test, y_pred, y_pred_proba
def evaluate_model(self, X_test, y_test, y_pred, y_pred_proba):
"""评估模型性能"""
print("\n" + "="*60)
print("模型详细评估")
print("="*60)
# 1. 混淆矩阵
cm = confusion_matrix(y_test, y_pred)
print("\n混淆矩阵:")
print(cm)
# 2. 分类报告
print("\n分类报告:")
target_names = ['Kama (1)', 'Rosa (2)', 'Canadian (3)']
print(classification_report(y_test, y_pred, target_names=target_names))
# 创建可视化图表
fig = plt.figure(figsize=(15, 10))
# 1. 混淆矩阵热图
ax1 = plt.subplot(2, 3, 1)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax1,
xticklabels=target_names,
yticklabels=target_names)
ax1.set_title('混淆矩阵热图', fontsize=14, fontweight='bold')
ax1.set_xlabel('预测标签', fontsize=12)
ax1.set_ylabel('真实标签', fontsize=12)
# 2. 各类别准确率
ax2 = plt.subplot(2, 3, 2)
class_accuracy = cm.diagonal() / cm.sum(axis=1)
bars = ax2.bar(range(3), class_accuracy, color=['#FF6B6B', '#4ECDC4', '#45B7D1'])
ax2.set_title('各类别准确率', fontsize=14, fontweight='bold')
ax2.set_xlabel('小麦品种', fontsize=12)
ax2.set_ylabel('准确率', fontsize=12)
ax2.set_xticks(range(3))
ax2.set_xticklabels(target_names)
ax2.set_ylim([0, 1])
ax2.grid(axis='y', alpha=0.3)
# 添加准确率标签
for bar, acc in zip(bars, class_accuracy):
height = bar.get_height()
ax2.text(bar.get_x() + bar.get_width()/2., height + 0.02,
f'{acc:.3f}', ha='center', va='bottom', fontweight='bold')
# 3. ROC曲线(多分类)
ax3 = plt.subplot(2, 3, 3)
from sklearn.preprocessing import label_binarize
# 将标签二值化
y_test_bin = label_binarize(y_test, classes=[1, 2, 3])
n_classes = y_test_bin.shape[1]
# 计算每个类别的ROC曲线
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
fpr[i], tpr[i], _ = roc_curve(y_test_bin[:, i], y_pred_proba[:, i])
roc_auc[i] = roc_auc_score(y_test_bin[:, i], y_pred_proba[:, i])
# 绘制每个类别的ROC曲线
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']
for i, color in zip(range(n_classes), colors):
ax3.plot(fpr[i], tpr[i], color=color, lw=2,
label=f'{target_names[i]} (AUC = {roc_auc[i]:.3f})')
ax3.plot([0, 1], [0, 1], 'k--', lw=2)
ax3.set_xlim([0.0, 1.0])
ax3.set_ylim([0.0, 1.05])
ax3.set_xlabel('假正率 (FPR)', fontsize=12)
ax3.set_ylabel('真正率 (TPR)', fontsize=12)
ax3.set_title('多分类ROC曲线', fontsize=14, fontweight='bold')
ax3.legend(loc="lower right", fontsize=10)
ax3.grid(alpha=0.3)
# 4. 特征重要性(系数绝对值)
ax4 = plt.subplot(2, 3, 4)
if hasattr(self.model, 'coef_'):
feature_importance = np.abs(self.model.coef_).mean(axis=0)
importance_df = pd.DataFrame({
'特征': feature_names,
'重要性': feature_importance
}).sort_values('重要性', ascending=True).tail(10)
bars = ax4.barh(range(len(importance_df)), importance_df['重要性'].values,
color=plt.cm.viridis(np.linspace(0.3, 0.9, len(importance_df))))
ax4.set_yticks(range(len(importance_df)))
ax4.set_yticklabels(importance_df['特征'])
ax4.set_title('逻辑回归特征重要性(Top 10)', fontsize=14, fontweight='bold')
ax4.set_xlabel('系数绝对值(平均)', fontsize=12)
ax4.invert_yaxis()
ax4.grid(axis='x', alpha=0.3)
# 5. 预测概率分布
ax5 = plt.subplot(2, 3, 5)
# 对每个类别的预测概率进行可视化
for i, color in enumerate(colors):
ax5.hist(y_pred_proba[y_test == i+1, i], bins=20, alpha=0.5,
color=color, label=target_names[i])
ax5.set_xlabel('预测概率', fontsize=12)
ax5.set_ylabel('频数', fontsize=12)
ax5.set_title('各类别预测概率分布', fontsize=14, fontweight='bold')
ax5.legend(fontsize=10)
ax5.grid(alpha=0.3)
# 6. 性能指标雷达图
ax6 = plt.subplot(2, 3, 6, polar=True)
# 计算每个类别的性能指标
metrics = ['准确率', '精确率', '召回率', 'F1分数']
metrics_values = [
class_accuracy.mean(),
precision_score(y_test, y_pred, average='macro'),
recall_score(y_test, y_pred, average='macro'),
f1_score(y_test, y_pred, average='macro')
]
angles = np.linspace(0, 2 * np.pi, len(metrics), endpoint=False).tolist()
metrics_values += metrics_values[:1] # 闭合图形
angles += angles[:1]
ax6.plot(angles, metrics_values, 'o-', linewidth=2, color='blue')
ax6.fill(angles, metrics_values, alpha=0.25, color='blue')
ax6.set_xticks(angles[:-1])
ax6.set_xticklabels(metrics, fontsize=10)
ax6.set_ylim([0.7, 1.0])
ax6.set_title('模型性能雷达图', fontsize=14, fontweight='bold', pad=20)
ax6.grid(True)
plt.suptitle('逻辑回归模型评估结果', fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()
# 打印模型系数
print("\n" + "="*60)
print("逻辑回归模型系数分析")
print("="*60)
if hasattr(self.model, 'coef_'):
print("\n各类别权重系数:")
for i, class_name in enumerate(target_names):
print(f"\n{class_name}:")
for j, feature in enumerate(feature_names):
coeff = self.model.coef_[i, j]
print(f" {feature}: {coeff:.6f}")
return {
'accuracy': accuracy_score(y_test, y_pred),
'precision': precision_score(y_test, y_pred, average='weighted'),
'recall': recall_score(y_test, y_pred, average='weighted'),
'f1': f1_score(y_test, y_pred, average='weighted'),
'confusion_matrix': cm
}
def cross_validation(self, X, y, cv=10):
"""交叉验证评估"""
print("\n" + "="*60)
print("10折交叉验证评估")
print("="*60)
# 数据标准化
X_scaled = self.scaler.fit_transform(X)
# 创建模型
model = LogisticRegression(
C=self.best_params['C'] if self.best_params else 1,
solver=self.best_params['solver'] if self.best_params else 'lbfgs',
max_iter=self.best_params['max_iter'] if self.best_params else 100,
multi_class='multinomial',
random_state=42
)
# 执行交叉验证
cv_scores = cross_val_score(model, X_scaled, y, cv=cv, scoring='accuracy')
print(f"\n交叉验证结果:")
print(f"准确率: {cv_scores.mean():.4f} (+/- {cv_scores.std() * 2:.4f})")
print(f"各折准确率: {cv_scores.round(4)}")
# 可视化交叉验证结果
plt.figure(figsize=(10, 6))
plt.plot(range(1, cv+1), cv_scores, 'o-', linewidth=2, markersize=8)
plt.axhline(y=cv_scores.mean(), color='r', linestyle='--',
label=f'平均准确率: {cv_scores.mean():.4f}')
plt.fill_between(range(1, cv+1),
cv_scores.mean() - cv_scores.std(),
cv_scores.mean() + cv_scores.std(),
alpha=0.2, color='gray')
plt.title(f'{cv}折交叉验证准确率', fontsize=14, fontweight='bold')
plt.xlabel('折数', fontsize=12)
plt.ylabel('准确率', fontsize=12)
plt.ylim([0.5, 1.0])
plt.legend(fontsize=10)
plt.grid(alpha=0.3)
plt.show()
return cv_scores
def predict_new_sample(self, sample_features):
"""预测新样本"""
if self.model is None:
print("错误: 模型未训练")
return None
# 标准化特征
sample_scaled = self.scaler.transform([sample_features])
# 预测
prediction = self.model.predict(sample_scaled)[0]
probabilities = self.model.predict_proba(sample_scaled)[0]
# 映射类别名称
category_names = {1: 'Kama', 2: 'Rosa', 3: 'Canadian'}
print(f"\n预测结果:")
print(f"预测类别: {category_names[prediction]} (类别{prediction})")
print(f"\n各类别概率:")
for i, prob in enumerate(probabilities):
print(f" {category_names[i+1]}: {prob:.4f}")
return prediction, probabilities
# 训练和评估逻辑回归模型
print("\n" + "="*60)
print("小麦品种分类 - 逻辑回归模型")
print("="*60)
# 创建模型实例
lr_model = WheatLogisticRegressionModel()
# 训练模型
X_test, y_test, y_pred, y_pred_proba = lr_model.train_model(X, y)
# 详细评估模型
metrics = lr_model.evaluate_model(X_test, y_test, y_pred, y_pred_proba)
# 交叉验证
cv_scores = lr_model.cross_validation(X, y, cv=10)
# 示例:预测新样本
print("\n" + "="*60)
print("新样本预测示例")
print("="*60)
# 创建一个示例样本(使用特征平均值)
sample_features = np.mean(X, axis=0)
prediction, probabilities = lr_model.predict_new_sample(sample_features)





2.5 模型结果对比
# ==============================================
# 6. 模型对比与部署
# ==============================================
def compare_classifiers(X, y):
"""比较不同分类器的性能"""
print("="*60)
print("不同分类器性能对比")
print("="*60)
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
# 数据标准化
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
# 定义要比较的分类器
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier
classifiers = {
'逻辑回归': LogisticRegression(multi_class='multinomial', random_state=42),
'K最近邻': KNeighborsClassifier(),
'支持向量机': SVC(probability=True, random_state=42),
'随机森林': RandomForestClassifier(random_state=42),
'朴素贝叶斯': GaussianNB(),
'决策树': DecisionTreeClassifier(random_state=42)
}
results = []
for name, clf in classifiers.items():
print(f"\n训练 {name}...")
# 训练模型
clf.fit(X_train_scaled, y_train)
# 预测
y_pred = clf.predict(X_test_scaled)
# 计算指标
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred, average='weighted')
recall = recall_score(y_test, y_pred, average='weighted')
f1 = f1_score(y_test, y_pred, average='weighted')
results.append({
'分类器': name,
'准确率': accuracy,
'精确率': precision,
'召回率': recall,
'F1分数': f1
})
print(f" 准确率: {accuracy:.4f}")
print(f" F1分数: {f1:.4f}")
# 创建结果DataFrame
results_df = pd.DataFrame(results)
results_df = results_df.sort_values('准确率', ascending=False)
# 可视化对比结果
# fig = plt.figure(figsize=(18, 12))
fig, axes = plt.subplots(1, 3, figsize=(40, 15))
# 1. 准确率对比
ax1 = plt.subplot(1, 3, 1)
bars1 = ax1.bar(range(len(results_df)), results_df['准确率'].values,
color=plt.cm.Set3(np.linspace(0, 1, len(results_df))))
ax1.set_title('不同分类器准确率对比', fontsize=14, fontweight='bold')
ax1.set_xlabel('分类器', fontsize=12)
ax1.set_ylabel('准确率', fontsize=12)
ax1.set_xticks(range(len(results_df)))
ax1.set_xticklabels(results_df['分类器'], rotation=45, ha='right')
ax1.set_ylim([0.7, 1.0])
ax1.grid(axis='y', alpha=0.3)
# 添加准确率标签
for bar, acc in zip(bars1, results_df['准确率']):
height = bar.get_height()
ax1.text(bar.get_x() + bar.get_width()/2., height + 0.005,
f'{acc:.4f}', ha='center', va='bottom', fontsize=9)
# 2. F1分数对比
ax2 = plt.subplot(1, 3, 2)
bars2 = ax2.bar(range(len(results_df)), results_df['F1分数'].values,
color=plt.cm.Set2(np.linspace(0, 1, len(results_df))))
ax2.set_title('不同分类器F1分数对比', fontsize=14, fontweight='bold')
ax2.set_xlabel('分类器', fontsize=12)
ax2.set_ylabel('F1分数', fontsize=12)
ax2.set_xticks(range(len(results_df)))
ax2.set_xticklabels(results_df['分类器'], rotation=45, ha='right')
ax2.set_ylim([0.7, 1.0])
ax2.grid(axis='y', alpha=0.3)
# 3. 训练时间对比(近似)
ax3 = plt.subplot(1, 3, 3)
training_times = [0.1, 0.05, 0.15, 0.2, 0.02, 0.03] # 示例值
bars3 = ax3.bar(range(len(results_df)), training_times,
color=plt.cm.Pastel1(np.linspace(0, 1, len(results_df))))
ax3.set_title('模型训练时间对比(相对)', fontsize=14, fontweight='bold')
ax3.set_xlabel('分类器', fontsize=12)
ax3.set_ylabel('相对训练时间', fontsize=12)
ax3.set_xticks(range(len(results_df)))
ax3.set_xticklabels(results_df['分类器'], rotation=45, ha='right')
ax3.grid(axis='y', alpha=0.3)
plt.suptitle('小麦品种分类模型性能对比分析', fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()
return results_df
# 执行模型对比
print("\n执行模型性能对比...")
results_df = compare_classifiers(X, y)
print("\n" + "="*60)
print("模型性能对比结果")
print("="*60)
print(results_df.round(4))



2.6 模型部署与保存
# ==============================================
# 7. 模型部署与保存
# ==============================================
import joblib
import json
from datetime import datetime
print("\n" + "="*60)
print("模型部署与保存")
print("="*60)
# 1. 保存训练好的模型
model_filename = "wheat_variety_logistic_regression.pkl"
joblib.dump(lr_model.model, model_filename)
print(f"✓ 逻辑回归模型已保存到: {model_filename}")
# 2. 保存标准化器
scaler_filename = "feature_scaler.pkl"
joblib.dump(lr_model.scaler, scaler_filename)
print(f"✓ 特征标准化器已保存到: {scaler_filename}")
# 3. 保存特征名称
feature_names_filename = "feature_names.json"
with open(feature_names_filename, 'w') as f:
json.dump(feature_names, f)
print(f"✓ 特征名称已保存到: {feature_names_filename}")
# 4. 保存模型元数据
model_metadata = {
'model_name': '小麦品种逻辑回归分类器',
'model_type': 'LogisticRegression',
'target_classes': {
1: 'Kama',
2: 'Rosa',
3: 'Canadian'
},
'features': feature_names,
'best_params': lr_model.best_params,
'performance': metrics,
'training_date': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
'dataset_info': {
'total_samples': len(df),
'features_count': len(feature_names),
'classes_count': 3
}
}
metadata_filename = "model_metadata.json"
with open(metadata_filename, 'w', encoding='utf-8') as f:
json.dump(model_metadata, f, indent=2, ensure_ascii=False)
print(f"✓ 模型元数据已保存到: {metadata_filename}")
# 5. 创建预测函数
def predict_wheat_variety(features_dict, model_path="wheat_variety_logistic_regression.pkl",
scaler_path="feature_scaler.pkl",
feature_names_path="feature_names.json"):
"""预测小麦品种"""
try:
# 加载模型和标准化器
model = joblib.load(model_path)
scaler = joblib.load(scaler_path)
# 加载特征名称
with open(feature_names_path, 'r') as f:
expected_features = json.load(f)
# 准备特征向量
features_vector = []
for feature in expected_features:
if feature in features_dict:
features_vector.append(features_dict[feature])
else:
raise ValueError(f"缺少特征: {feature}")
# 标准化特征
features_scaled = scaler.transform([features_vector])
# 预测
prediction = model.predict(features_scaled)[0]
probabilities = model.predict_proba(features_scaled)[0]
# 映射类别名称
category_names = {1: 'Kama', 2: 'Rosa', 3: 'Canadian'}
result = {
'predicted_class': int(prediction),
'predicted_variety': category_names[prediction],
'probabilities': {
category_names[i+1]: float(prob)
for i, prob in enumerate(probabilities)
},
'confidence': float(max(probabilities))
}
return result
except Exception as e:
return {'error': str(e)}
# 6. 示例:使用保存的模型进行预测
print("\n示例:使用保存的模型进行预测")
print("="*50)
# 创建一个测试样本
test_sample = {}
for i, feature in enumerate(feature_names):
test_sample[feature] = np.mean(X[:, i])
print(f"测试样本特征值:")
for feature, value in test_sample.items():
print(f" {feature}: {value:.4f}")
# 进行预测
prediction_result = predict_wheat_variety(test_sample)
if 'error' not in prediction_result:
print(f"\n预测结果:")
print(f" 预测类别: {prediction_result['predicted_class']}")
print(f" 预测品种: {prediction_result['predicted_variety']}")
print(f" 置信度: {prediction_result['confidence']:.4f}")
print(f"\n各类别概率:")
for variety, prob in prediction_result['probabilities'].items():
print(f" {variety}: {prob:.4f}")
else:
print(f"预测错误: {prediction_result['error']}")

三、项目核心步骤分析
3.1 数据探索分析(EDA)
- 数据质量检查:验证数据集完整性,确保无缺失值和重复值
- 分布可视化:通过密度图、箱线图、Q-Q图分析特征分布
- 相关性分析:计算特征间的Pearson相关系数,识别高度相关特征
- 类别差异分析:通过小提琴图展示不同小麦品种的特征分布差异
3.2 特征工程流程
- 基础特征提取:使用原始7个物理属性特征
- 交互特征创建:计算面积周长比、长宽比、形状指数等新特征
- 多项式特征:添加特征的平方和平方根,捕捉非线性关系
- 特征选择:使用ANOVA F值选择最重要的10个特征
3.3 模型训练与优化
- 逻辑回归模型:使用多分类逻辑回归作为基础模型
- 超参数调优:网格搜索优化C值、求解器和最大迭代次数
- 交叉验证:10折交叉验证确保模型稳定性
- 性能评估:计算准确率、精确率、召回率、F1分数等多个指标
3.4 模型部署
- 模型保存:使用joblib保存训练好的模型和标准化器
- 元数据管理:保存特征名称、模型参数和性能指标
- 预测函数:封装预测逻辑,支持新样本的分类预测
四、关键技术要点
4.1 数据预处理技巧
# 数据标准化
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_selected)
# 类别平衡处理
train_test_split(X, y, stratify=y)
4.2 特征工程策略
- 交互特征:捕捉特征间的相互关系
- 多项式特征:增强模型的非线性拟合能力
- 特征选择:基于统计显著性筛选重要特征
4.3 模型评估方法
- 混淆矩阵:直观展示分类结果
- ROC曲线:评估各类别的分类性能
- 交叉验证:减少过拟合风险
- 多指标评估:全面评价模型性能
五、运行结果与性能分析
5.1 模型性能指标
- 准确率:通常在90%以上
- F1分数:衡量精确率和召回率的平衡
- ROC AUC:各类别AUC值均接近1.0
- 交叉验证稳定性:10折CV准确率方差小
5.2 关键发现
- 面积和长度是最重要的区分特征
- Kama品种在紧凑度上与其他品种有明显差异
- 逻辑回归在小数据集上表现优秀,且解释性强
- 特征工程显著提升了模型性能
六、实际应用建议
6.1 数据采集建议
- 测量标准化:确保测量工具和方法的统一性
- 样本代表性:采集不同生长条件下的小麦样本
- 质量控制:定期校准测量设备
6.2 模型优化方向
- 集成学习:尝试随机森林、梯度提升等集成方法
- 深度学习:使用神经网络处理更复杂的非线性关系
- 在线学习:实现模型的增量学习和更新
6.3 部署注意事项
- 环境依赖:确保部署环境与训练环境的一致性
- 监控维护:定期监控模型性能,及时更新模型
- 用户界面:为农业技术人员提供友好的预测界面
七、总结
本文实现了一个完整的小麦品种分类机器学习项目,从数据探索到模型部署的全流程。项目展示了如何使用经典的机器学习方法解决农业领域的实际问题,具有以下特点:
- 流程完整:覆盖了机器学习项目的所有关键步骤
- 代码规范:采用模块化设计,易于理解和复用
- 可视化丰富:提供了全面的数据分析和结果展示
- 实用性强:可直接应用于实际的小麦品种识别场景
该项目不仅对农业领域的研究者有参考价值,也为机器学习初学者提供了一个优秀的学习案例。通过调整特征工程和模型参数,可以轻松扩展到其他农作物的品种识别任务中。
注: 博主目前收集了6900+份相关数据集,有想要的可以领取部分数据,关注下方公众号或添加微信:


更多推荐

所有评论(0)