MNIST数据集上构建简单的卷积神经网络、训练并分类(MATLAB例)
将 70% 的图像用于训练,15% 的图像用于验证,15% 的图像用于测试。指定 “randomized” 以将每个类中指定比例的文件分配给新数据集。splitEachLabel 函数将图像数据存储拆分为三个新数据存储。数据集有 10 个类,数据集中每个图像的像素数为 28×28×1。此示例说明如何使用深度网络设计器创建简单的卷积神经网络来进行深度学习分类。使用 minibatchpredict
·
此示例说明如何使用深度网络设计器创建简单的卷积神经网络来进行深度学习分类。卷积神经网络是深度学习的基本工具,尤其适用于图像识别。
imageDatastore 函数根据文件夹名称自动对图像加标签。数据集有 10 个类,数据集中每个图像的像素数为 28×28×1。
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
'IncludeSubfolders',true,'LabelSource','foldernames');
>> labelCount = countEachLabel(imds)
labelCount =
10×2 table
Label Count
_____ _____
0 1000
1 1000
2 1000
3 1000
4 1000
5 1000
6 1000
7 1000
8 1000
9 1000

将数据划分为训练、验证和测试数据集。将 70% 的图像用于训练,15% 的图像用于验证,15% 的图像用于测试。指定 “randomized” 以将每个类中指定比例的文件分配给新数据集。splitEachLabel 函数将图像数据存储拆分为三个新数据存储。
[imdsTrain,imdsValidation,imdsTest] = splitEachLabel(imds,0.7,0.15,0.15,"randomized");
构建网络
layers = [
imageInputLayer([size(img,1) size(img,2) 1])
convolution2dLayer(3,8,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,32,'Padding','same')
batchNormalizationLayer
reluLayer
fullyConnectedLayer(10)
softmaxLayer];
net_1 = dlnetwork(layers);

指定训练选项。
在选项中进行选择需要经验分析。
options = trainingOptions("sgdm", ...
MaxEpochs=4, ...
ValidationData=imdsValidation, ...
ValidationFrequency=30, ...
Plots="training-progress", ...
Metrics="accuracy", ...
Verbose=false);
训练神经网络
使用 trainnet 函数训练神经网络。由于目的是分类,因此使用交叉熵损失。
net = trainnet(imdsTrain,net_1,"crossentropy",options);

预测
使用 minibatchpredict 函数进行预测,并使用 scores2label 函数将分数转换为标签。
对于单标签分类,评估准确度。准确度是指正确预测的百分比。
YTestscores = minibatchpredict(net,imdsTest);
YTest = scores2label(YTestscores,classNames);
TTest = imdsTest.Labels;
accuracy = sum(YTest == TTest)/numel(TTest)
>> accuracy
accuracy =
0.9867



可视化一些预测值。
numTestObservations = numel(imdsTest.Files);
idx = randi(numTestObservations,9,1);
figure
tiledlayout("flow")
for i = 1:9
nexttile
img = readimage(imdsTest,idx(i));
imshow(img)
title("Predicted Class: " + string(YValidation(idx(i))))
end

更多推荐
所有评论(0)