一、环境配置与数据准备

1.1 环境要求
  • MATLAB版本:R2021a及以上(需安装Deep Learning Toolbox)
  • GPU支持:推荐NVIDIA CUDA兼容显卡(通过gpuDevice验证)
1.2 数据组织结构
dataset/
├── train/
│   ├── cat/
│   └── dog/
└── validation/
    ├── cat/
    └── dog/
1.3 数据加载与预处理
% 创建图像数据存储
imdsTrain = imageDatastore('dataset/train', ...
    'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');

imdsValidation = imageDatastore('dataset/validation', ...
    'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');

% 数据增强(随机旋转±20°,水平翻转)
augmenter = imageDataAugmenter(...
    'RandRotation', [-20, 20], ...
    'RandXReflection', true);

% 调整图像大小并增强
augimdsTrain = augmentedImageDatastore([227 227], imdsTrain, 'DataAugmentation', augmenter);
augimdsValidation = augmentedImageDatastore([227 227], imdsValidation);

二、模型构建策略

2.1 迁移学习(推荐方法)
% 加载预训练模型(AlexNet/ResNet-50/EfficientNet)
net = alexnet;

% 修改网络结构
lgraph = layerGraph(net);
newFCLayer = fullyConnectedLayer(2, 'Name', 'fc_new', 'WeightLearnRateFactor', 10);
newOutputLayer = classificationLayer('Name', 'output_new');

% 替换最后两层
lgraph = replaceLayer(lgraph, 'fc7', newFCLayer);
lgraph = replaceLayer(lgraph, 'ClassificationLayer_fc7', newOutputLayer);
2.2 自定义CNN架构
layers = [
    imageInputLayer([227 227 3])
    
    % 卷积块1
    convolution2dLayer(3, 32, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2, 'Stride', 2)
    
    % 卷积块2
    convolution2dLayer(3, 64, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2, 'Stride', 2)
    
    % 全连接层
    fullyConnectedLayer(64)
    reluLayer
    dropoutLayer(0.5)
    
    % 输出层
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer];

三、模型训练与调优

3.1 训练参数配置
options = trainingOptions('adam', ...
    'MaxEpochs', 20, ...
    'MiniBatchSize', 64, ...
    'InitialLearnRate', 0.001, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', augimdsValidation, ...
    'ValidationFrequency', 30, ...
    'Verbose', false, ...
    'Plots', 'training-progress', ...
    'ExecutionEnvironment', 'multi-gpu'); % 支持多GPU加速
3.2 模型训练
[netTrained, info] = trainNetwork(augimdsTrain, lgraph, options);
3.3 性能评估
% 验证集预测
YPred = classify(netTrained, augimdsValidation);
YValidation = imdsValidation.Labels;

% 计算准确率
accuracy = mean(YPred == YValidation);
fprintf('Validation Accuracy: %.2f%%
', accuracy*100);

% 混淆矩阵
cm = confusionchart(YValidation, YPred);
cm.Title = 'Confusion Matrix';
cm.ColumnSummary = 'column-normalized';

四、实战案例:花卉分类

5.1 数据集准备

下载并解压Oxford 102 Flowers数据集,按类别组织文件夹。

5.2 完整代码
% 加载数据
[imdsTrain, imdsValidation] = loadFlowerDataset();

% 数据增强
augmenter = imageDataAugmenter('RandRotation', [-15,15]);
augimdsTrain = augmentedImageDatastore([227 227], imdsTrain, 'DataAugmentation', augmenter);

% 迁移学习
net = alexnet;
lgraph = layerGraph(net);
layers = [lgraph.Layers(1:end-3) ...  % 移除最后3层
    fullyConnectedLayer(102, 'WeightLearnRateFactor', 10) ...
    softmaxLayer ...
    classificationLayer];

% 训练配置
options = trainingOptions('sgdm', ...
    'MaxEpochs', 15, ...
    'MiniBatchSize', 32, ...
    'InitialLearnRate', 0.001, ...
    'ExecutionEnvironment', 'gpu');

% 开始训练
netTrained = trainNetwork(augimdsTrain, lgraph, options);

% 评估模型
YPred = classify(netTrained, imdsValidation);
accuracy = mean(YPred == imdsValidation.Labels);

五、模型部署

6.1 MATLAB实时推理
% 加载测试图像
img = imread('test_flower.jpg');
imgResized = imresize(img, [227 227]);

% 预测
label = classify(netTrained, imgResized);
imshow(img);
title(sprintf('Predicted: %s (%.2f%%)', label, max(scores)*100));
6.2 生成TFLite模型
converter = dlquantizer(netTrained, 'Target', 'TensorFlow Lite');
converter.Optimize = true;
converter.Precision = 'int8';
tfliteModel = convert(converter);
save('flower_classifier.tflite', 'tfliteModel');

十、参考

  1. MathWorks官方文档:Deep Learning in MATLAB] ww2.mathworks.cn/help/deeplearning/

  2. 代码 运用深度学习模型实现图像的分类 www.3dddown.com/csa/55199.html

  3. AlexNet迁移学习示例:Image Category Classification ww2.mathworks.cn/help/deeplearning/ug/image-category-classification-using-deep-learning.html

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐