图像分类

此示例说明如何使用深度网络设计器创建简单的卷积神经网络来进行深度学习分类。卷积神经网络是深度学习的基本工具,尤其适用于图像识别。

imageDatastore 函数根据文件夹名称自动对图像加标签。数据集有 10 个类,数据集中每个图像的像素数为 28×28×1。

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');

>> labelCount = countEachLabel(imds)

labelCount =

  10×2 table

    Label    Count
    _____    _____

      0      1000 
      1      1000 
      2      1000 
      3      1000 
      4      1000 
      5      1000 
      6      1000 
      7      1000 
      8      1000 
      9      1000 

在这里插入图片描述
将数据划分为训练、验证和测试数据集。将 70% 的图像用于训练,15% 的图像用于验证,15% 的图像用于测试。指定 “randomized” 以将每个类中指定比例的文件分配给新数据集。splitEachLabel 函数将图像数据存储拆分为三个新数据存储。

[imdsTrain,imdsValidation,imdsTest] = splitEachLabel(imds,0.7,0.15,0.15,"randomized");

构建网络

layers = [
    imageInputLayer([size(img,1) size(img,2) 1])
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    fullyConnectedLayer(10)
    softmaxLayer];

net_1 = dlnetwork(layers);

在这里插入图片描述
指定训练选项。
在选项中进行选择需要经验分析。

options = trainingOptions("sgdm", ...
    MaxEpochs=4, ...
    ValidationData=imdsValidation, ...
    ValidationFrequency=30, ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);

训练神经网络
使用 trainnet 函数训练神经网络。由于目的是分类,因此使用交叉熵损失。

net = trainnet(imdsTrain,net_1,"crossentropy",options);

在这里插入图片描述

预测
使用 minibatchpredict 函数进行预测,并使用 scores2label 函数将分数转换为标签。
对于单标签分类,评估准确度。准确度是指正确预测的百分比。

YTestscores = minibatchpredict(net,imdsTest);
YTest = scores2label(YTestscores,classNames);
TTest = imdsTest.Labels;
accuracy = sum(YTest == TTest)/numel(TTest)
>> accuracy

accuracy =

    0.9867

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
可视化一些预测值。

numTestObservations = numel(imdsTest.Files);
idx = randi(numTestObservations,9,1);

figure
tiledlayout("flow")
for i = 1:9
    nexttile
    img = readimage(imdsTest,idx(i));
    imshow(img)
    title("Predicted Class: " + string(YValidation(idx(i))))
end

在这里插入图片描述

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐