基于MATLAB实现深度学习图像分类
基于MATLAB实现深度学习图像分类
·
一、环境配置与数据准备
1.1 环境要求
- MATLAB版本:R2021a及以上(需安装Deep Learning Toolbox)
- GPU支持:推荐NVIDIA CUDA兼容显卡(通过
gpuDevice验证)
1.2 数据组织结构
dataset/
├── train/
│ ├── cat/
│ └── dog/
└── validation/
├── cat/
└── dog/
1.3 数据加载与预处理
% 创建图像数据存储
imdsTrain = imageDatastore('dataset/train', ...
'IncludeSubfolders', true, ...
'LabelSource', 'foldernames');
imdsValidation = imageDatastore('dataset/validation', ...
'IncludeSubfolders', true, ...
'LabelSource', 'foldernames');
% 数据增强(随机旋转±20°,水平翻转)
augmenter = imageDataAugmenter(...
'RandRotation', [-20, 20], ...
'RandXReflection', true);
% 调整图像大小并增强
augimdsTrain = augmentedImageDatastore([227 227], imdsTrain, 'DataAugmentation', augmenter);
augimdsValidation = augmentedImageDatastore([227 227], imdsValidation);
二、模型构建策略
2.1 迁移学习(推荐方法)
% 加载预训练模型(AlexNet/ResNet-50/EfficientNet)
net = alexnet;
% 修改网络结构
lgraph = layerGraph(net);
newFCLayer = fullyConnectedLayer(2, 'Name', 'fc_new', 'WeightLearnRateFactor', 10);
newOutputLayer = classificationLayer('Name', 'output_new');
% 替换最后两层
lgraph = replaceLayer(lgraph, 'fc7', newFCLayer);
lgraph = replaceLayer(lgraph, 'ClassificationLayer_fc7', newOutputLayer);
2.2 自定义CNN架构
layers = [
imageInputLayer([227 227 3])
% 卷积块1
convolution2dLayer(3, 32, 'Padding', 'same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2, 'Stride', 2)
% 卷积块2
convolution2dLayer(3, 64, 'Padding', 'same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2, 'Stride', 2)
% 全连接层
fullyConnectedLayer(64)
reluLayer
dropoutLayer(0.5)
% 输出层
fullyConnectedLayer(2)
softmaxLayer
classificationLayer];
三、模型训练与调优
3.1 训练参数配置
options = trainingOptions('adam', ...
'MaxEpochs', 20, ...
'MiniBatchSize', 64, ...
'InitialLearnRate', 0.001, ...
'Shuffle', 'every-epoch', ...
'ValidationData', augimdsValidation, ...
'ValidationFrequency', 30, ...
'Verbose', false, ...
'Plots', 'training-progress', ...
'ExecutionEnvironment', 'multi-gpu'); % 支持多GPU加速
3.2 模型训练
[netTrained, info] = trainNetwork(augimdsTrain, lgraph, options);
3.3 性能评估
% 验证集预测
YPred = classify(netTrained, augimdsValidation);
YValidation = imdsValidation.Labels;
% 计算准确率
accuracy = mean(YPred == YValidation);
fprintf('Validation Accuracy: %.2f%%
', accuracy*100);
% 混淆矩阵
cm = confusionchart(YValidation, YPred);
cm.Title = 'Confusion Matrix';
cm.ColumnSummary = 'column-normalized';
四、实战案例:花卉分类
5.1 数据集准备
下载并解压Oxford 102 Flowers数据集,按类别组织文件夹。
5.2 完整代码
% 加载数据
[imdsTrain, imdsValidation] = loadFlowerDataset();
% 数据增强
augmenter = imageDataAugmenter('RandRotation', [-15,15]);
augimdsTrain = augmentedImageDatastore([227 227], imdsTrain, 'DataAugmentation', augmenter);
% 迁移学习
net = alexnet;
lgraph = layerGraph(net);
layers = [lgraph.Layers(1:end-3) ... % 移除最后3层
fullyConnectedLayer(102, 'WeightLearnRateFactor', 10) ...
softmaxLayer ...
classificationLayer];
% 训练配置
options = trainingOptions('sgdm', ...
'MaxEpochs', 15, ...
'MiniBatchSize', 32, ...
'InitialLearnRate', 0.001, ...
'ExecutionEnvironment', 'gpu');
% 开始训练
netTrained = trainNetwork(augimdsTrain, lgraph, options);
% 评估模型
YPred = classify(netTrained, imdsValidation);
accuracy = mean(YPred == imdsValidation.Labels);
五、模型部署
6.1 MATLAB实时推理
% 加载测试图像
img = imread('test_flower.jpg');
imgResized = imresize(img, [227 227]);
% 预测
label = classify(netTrained, imgResized);
imshow(img);
title(sprintf('Predicted: %s (%.2f%%)', label, max(scores)*100));
6.2 生成TFLite模型
converter = dlquantizer(netTrained, 'Target', 'TensorFlow Lite');
converter.Optimize = true;
converter.Precision = 'int8';
tfliteModel = convert(converter);
save('flower_classifier.tflite', 'tfliteModel');
十、参考
-
MathWorks官方文档:Deep Learning in MATLAB] ww2.mathworks.cn/help/deeplearning/
-
代码 运用深度学习模型实现图像的分类 www.3dddown.com/csa/55199.html
-
AlexNet迁移学习示例:Image Category Classification ww2.mathworks.cn/help/deeplearning/ug/image-category-classification-using-deep-learning.html
更多推荐
所有评论(0)