基于神经网络的手写数字识别系统

结合模板匹配和神经网络两种方法进行手写数字识别。这个系统包括图像预处理、特征提取、神经网络训练和可视化分析。

%% 基于神经网络的手写数字识别系统

%% 清理工作区
clear; clc; close all;

%% 加载手写数字数据集
% 使用MATLAB自带的手写数字数据集
digitDatasetPath = fullfile(toolboxdir('nnet'), 'nndemos', 'nndatasets', 'DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');

% 显示数据集信息
disp(['数据集包含 ', num2str(numel(imds.Files)), ' 张图像']);
countEachLabel(imds)

% 显示随机样本图像
figure('Name', '手写数字样本', 'Position', [100, 100, 800, 400]);
perm = randperm(numel(imds.Files), 20);
for i = 1:20
    subplot(4, 5, i);
    img = readimage(imds, perm(i));
    imshow(img);
    title(char(imds.Labels(perm(i))));
end
sgtitle('手写数字样本展示');

%% 数据预处理
% 将图像调整为28x28像素并转换为灰度图
processedImages = zeros(28, 28, 1, numel(imds.Files));
for i = 1:numel(imds.Files)
    img = readimage(imds, i);
    
    % 转换为灰度图
    if size(img, 3) == 3
        img = rgb2gray(img);
    end
    
    % 调整大小
    img = imresize(img, [28, 28]);
    
    % 归一化处理 [0, 1]
    img = im2double(img);
    
    % 图像二值化
    img = imbinarize(img);
    
    % 存储处理后的图像
    processedImages(:, :, 1, i) = img;
end

% 显示预处理后的图像
figure('Name', '预处理后的图像', 'Position', [100, 100, 800, 400]);
for i = 1:20
    subplot(4, 5, i);
    imshow(processedImages(:, :, 1, perm(i)));
    title(char(imds.Labels(perm(i))));
end
sgtitle('预处理后的手写数字');

%% 创建模板匹配系统
% 为每个数字创建平均模板
templates = zeros(28, 28, 10);
digitCounts = zeros(1, 10);

for i = 1:numel(imds.Files)
    digit = double(imds.Labels(i));
    templates(:, :, digit+1) = templates(:, :, digit+1) + processedImages(:, :, 1, i);
    digitCounts(digit+1) = digitCounts(digit+1) + 1;
end

% 计算平均模板
for i = 1:10
    if digitCounts(i) > 0
        templates(:, :, i) = templates(:, :, i) / digitCounts(i);
    end
end

% 显示模板
figure('Name', '数字模板', 'Position', [100, 100, 1000, 400]);
for i = 0:9
    subplot(2, 5, i+1);
    imshow(templates(:, :, i+1));
    title(['数字 ', num2str(i), ' 模板']);
end
sgtitle('模板匹配使用的数字模板');

%% 模板匹配测试
% 测试模板匹配的准确率
numTest = 200; % 测试样本数量
testIndices = randperm(numel(imds.Files), numTest);
templateResults = zeros(1, numTest);
templateCorrect = 0;

figure('Name', '模板匹配结果', 'Position', [100, 100, 1200, 600]);
colormap gray;

for i = 1:min(20, numTest) % 只显示前20个结果
    idx = testIndices(i);
    testImg = processedImages(:, :, 1, idx);
    trueLabel = double(imds.Labels(idx));
    
    % 计算与每个模板的相似度(使用相关系数)
    correlations = zeros(1, 10);
    for digit = 0:9
        corrMatrix = corrcoef(testImg(:), templates(:, :, digit+1)(:));
        correlations(digit+1) = corrMatrix(1, 2);
    end
    
    % 选择最相似的数字
    [~, predLabel] = max(correlations);
    predLabel = predLabel - 1;
    
    % 记录结果
    templateResults(i) = (predLabel == trueLabel);
    if predLabel == trueLabel
        templateCorrect = templateCorrect + 1;
    end
    
    % 显示结果
    subplot(4, 5, i);
    imshow(testImg);
    if predLabel == trueLabel
        title(sprintf('True: %d\nPred: %d', trueLabel, predLabel), 'Color', 'g');
    else
        title(sprintf('True: %d\nPred: %d', trueLabel, predLabel), 'Color', 'r');
    end
end

% 计算准确率
templateAccuracy = templateCorrect / numTest;
fprintf('模板匹配准确率: %.2f%%\n', templateAccuracy * 100);
sgtitle(sprintf('模板匹配结果 (准确率: %.2f%%)', templateAccuracy*100));

%% 准备神经网络数据
% 划分训练集和测试集 (70% 训练, 30% 测试)
[trainIdx, testIdx] = dividerand(numel(imds.Files), 0.7, 0.3);

% 创建训练集
XTrain = processedImages(:, :, :, trainIdx);
YTrain = categorical(imds.Labels(trainIdx));

% 创建测试集
XTest = processedImages(:, :, :, testIdx);
YTest = categorical(imds.Labels(testIdx));

% 显示数据集大小
fprintf('训练集大小: %d\n', numel(trainIdx));
fprintf('测试集大小: %d\n', numel(testIdx));

%% 构建神经网络模型
layers = [
    imageInputLayer([28 28 1], 'Name', 'input')
    
    convolution2dLayer(5, 32, 'Padding', 'same', 'Name', 'conv1')
    batchNormalizationLayer('Name', 'bn1')
    reluLayer('Name', 'relu1')
    
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'pool1')
    
    convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'conv2')
    batchNormalizationLayer('Name', 'bn2')
    reluLayer('Name', 'relu2')
    
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'pool2')
    
    fullyConnectedLayer(128, 'Name', 'fc1')
    reluLayer('Name', 'relu3')
    dropoutLayer(0.4, 'Name', 'dropout')
    
    fullyConnectedLayer(10, 'Name', 'fc2')
    softmaxLayer('Name', 'softmax')
    classificationLayer('Name', 'output')
];

% 可视化网络结构
figure('Name', '神经网络结构');
plot(layerGraph(layers));
title('卷积神经网络结构');

%% 设置训练选项
options = trainingOptions('adam', ...
    'InitialLearnRate', 0.001, ...
    'MaxEpochs', 15, ...
    'MiniBatchSize', 128, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', {XTest, YTest}, ...
    'ValidationFrequency', 30, ...
    'Verbose', true, ...
    'Plots', 'training-progress', ...
    'ExecutionEnvironment', 'auto');

%% 训练神经网络
disp('开始训练神经网络...');
net = trainNetwork(XTrain, YTrain, layers, options);
disp('神经网络训练完成!');

%% 评估神经网络性能
% 在整个测试集上进行预测
YPred = classify(net, XTest);

% 计算准确率
accuracy = sum(YPred == YTest) / numel(YTest);
fprintf('神经网络测试准确率: %.2f%%\n', accuracy * 100);

% 混淆矩阵
figure('Name', '混淆矩阵', 'Position', [100, 100, 800, 700]);
cm = confusionmat(YTest, YPred);
confusionchart(cm, categories(YTest));
title(sprintf('混淆矩阵 (准确率: %.2f%%)', accuracy*100));

%% 可视化神经网络预测结果
% 选择一些样本进行可视化
numSamplesToShow = 20;
testSampleIndices = randperm(numel(testIdx), numSamplesToShow);

figure('Name', '神经网络预测结果', 'Position', [100, 100, 1200, 600]);
colormap gray;

for i = 1:numSamplesToShow
    idx = testIdx(testSampleIndices(i));
    img = processedImages(:, :, 1, idx);
    trueLabel = char(imds.Labels(idx));
    predLabel = char(YPred(testSampleIndices(i)));
    
    subplot(4, 5, i);
    imshow(img);
    
    if strcmp(trueLabel, predLabel)
        title(sprintf('True: %s\nPred: %s', trueLabel, predLabel), 'Color', 'g');
    else
        title(sprintf('True: %s\nPred: %s', trueLabel, predLabel), 'Color', 'r');
    end
end
sgtitle(sprintf('神经网络预测结果 (准确率: %.2f%%)', accuracy*100));

%% 特征可视化
% 提取卷积层的激活
conv1Activations = activations(net, XTest, 'conv1');
conv2Activations = activations(net, XTest, 'conv2');

% 显示卷积层特征图
sampleIdx = testSampleIndices(1); % 使用第一个测试样本
sampleImg = XTest(:, :, :, sampleIdx);

figure('Name', '卷积层特征可视化', 'Position', [100, 100, 1200, 800]);

% 原始图像
subplot(3, 1, 1);
imshow(sampleImg);
title('原始图像');

% 第一卷积层的特征图
subplot(3, 1, 2);
montage(reshape(conv1Activations(:, :, :, sampleIdx), [28, 28]));
title('第一卷积层特征图');

% 第二卷积层的特征图
subplot(3, 1, 3);
montage(reshape(conv2Activations(:, :, :, sampleIdx), [14, 14]));
title('第二卷积层特征图');

%% 手写数字识别演示
% 创建一个简单的绘图界面,让用户手写数字
f = figure('Name', '手写数字识别演示', 'Position', [200, 200, 600, 500]);
ax = axes('Parent', f, 'Position', [0.1, 0.2, 0.8, 0.7]);
title('在下方区域手写一个数字');

% 创建绘图区域
drawingArea = uicontrol('Style', 'text', 'Position', [60, 100, 280, 280], ...
    'BackgroundColor', 'white');
axes('Position', [0.1, 0.2, 0.8, 0.7]);

% 初始化绘图数据
drawing = false;
lastPoint = [0, 0];
imgData = ones(280, 280) * 255; % 白色背景

% 鼠标回调函数
set(gcf, 'WindowButtonDownFcn', @startDrawing);
set(gcf, 'WindowButtonUpFcn', @stopDrawing);
set(gcf, 'WindowButtonMotionFcn', @draw);

% 创建按钮
uicontrol('Style', 'pushbutton', 'String', '识别', ...
    'Position', [100, 50, 100, 30], ...
    'Callback', @recognizeDigit);

uicontrol('Style', 'pushbutton', 'String', '清除', ...
    'Position', [220, 50, 100, 30], ...
    'Callback', @clearDrawing);

% 结果显示区域
resultText = uicontrol('Style', 'text', 'String', '结果将显示在这里', ...
    'Position', [100, 20, 200, 20], ...
    'FontSize', 12, 'FontWeight', 'bold');

%% 绘图回调函数
function startDrawing(~, ~)
    drawing = true;
end

function stopDrawing(~, ~)
    drawing = false;
    lastPoint = [0, 0];
end

function draw(~, ~)
    if drawing
        currentPoint = get(gca, 'CurrentPoint');
        x = round(currentPoint(1, 1));
        y = round(currentPoint(1, 2));
        
        % 确保坐标在绘图区域内
        if x >= 1 && x <= 280 && y >= 1 && y <= 280
            if lastPoint(1) > 0 && lastPoint(2) > 0
                % 在两点之间画线
                lineX = linspace(lastPoint(1), x, 50);
                lineY = linspace(lastPoint(2), y, 50);
                
                for k = 1:50
                    px = round(lineX(k));
                    py = round(lineY(k));
                    if px >= 1 && px <= 280 && py >= 1 && py <= 280
                        % 绘制粗线
                        for i = -2:2
                            for j = -2:2
                                if px+i > 0 && px+i <= 280 && py+j > 0 && py+j <= 280
                                    imgData(py+j, px+i) = 0; % 黑色
                                end
                            end
                        end
                    end
                end
            end
            
            % 更新图像
            imshow(imgData, 'Parent', gca);
            lastPoint = [x, y];
        end
    end
end

function recognizeDigit(~, ~)
    % 预处理用户绘制的图像
    userImg = imresize(imgData, [28, 28]);
    userImg = imcomplement(userImg); % 反转为黑底白字
    userImg = im2double(userImg);
    
    % 使用神经网络进行预测
    [predLabel, scores] = classify(net, userImg);
    
    % 显示结果
    set(resultText, 'String', sprintf('识别结果: %s (置信度: %.2f%%)', char(predLabel), max(scores)*100));
    
    % 显示处理后的图像
    figure('Name', '预处理后的手写数字');
    subplot(1, 2, 1);
    imshow(imcomplement(imgData)); % 原始手写图像
    title('用户手写数字');
    
    subplot(1, 2, 2);
    imshow(userImg);
    title('预处理后的图像');
end

function clearDrawing(~, ~)
    imgData = ones(280, 280) * 255; % 重置为白色背景
    imshow(imgData, 'Parent', gca);
    set(resultText, 'String', '结果将显示在这里');
end

系统功能与实现详解

1. 系统架构

本系统包含三个主要模块:

  • 模板匹配模块:创建数字模板并进行匹配识别
  • 神经网络模块:构建并训练卷积神经网络
  • 交互演示模块:允许用户手写数字进行实时识别

2. 数据处理流程

  1. 数据加载

    • 使用MATLAB自带的手写数字数据集
    • 包含0-9共10类数字图像
  2. 图像预处理

    % 转换为灰度图
    img = rgb2gray(img);
    
    % 调整大小为28x28像素
    img = imresize(img, [28, 28]);
    
    % 归一化处理
    img = im2double(img);
    
    % 图像二值化
    img = imbinarize(img);
    
  3. 模板创建

    • 对每个数字类别的图像求平均
    • 生成0-9的数字模板

3. 模板匹配算法

% 计算与每个模板的相关系数
correlations = zeros(1, 10);
for digit = 0:9
    corrMatrix = corrcoef(testImg(:), templates(:, :, digit+1)(:));
    correlations(digit+1) = corrMatrix(1, 2);
end

% 选择最相似的数字
[~, predLabel] = max(correlations);

参考源码 手写体识别 模板匹配识别方法 youwenfan.com/contentcsa/78091.html

4. 神经网络架构

本系统使用了一个高效的卷积神经网络结构:

层类型 参数设置 输出尺寸
输入层 28x28x1图像 28x28x1
卷积层1 5x5核, 32个滤波器 28x28x32
批量归一化层1 - 28x28x32
ReLU激活层1 - 28x28x32
最大池化层1 2x2池化, 步长2 14x14x32
卷积层2 3x3核, 64个滤波器 14x14x64
批量归一化层2 - 14x14x64
ReLU激活层2 - 14x14x64
最大池化层2 2x2池化, 步长2 7x7x64
全连接层1 128个神经元 128
ReLU激活层3 - 128
Dropout层 丢弃率40% 128
全连接层2 10个神经元 10
Softmax层 - 10
分类层 - -

5. 训练配置

options = trainingOptions('adam', ...
    'InitialLearnRate', 0.001, ...
    'MaxEpochs', 15, ...
    'MiniBatchSize', 128, ...
    'ValidationData', {XTest, YTest}, ...
    'Plots', 'training-progress');

6. 性能比较

方法 准确率 优点 缺点
模板匹配 75-85% 实现简单,计算快速 对形变和旋转敏感
神经网络 98-99% 鲁棒性强,识别精度高 需要大量数据和训练时间

7. 交互演示功能

系统提供了一个绘图界面:

  1. 用户在白色画布上手写数字
  2. 点击"识别"按钮进行预测
  3. 点击"清除"按钮重置画布
  4. 显示识别结果和置信度

关键技术与创新点

  1. 多方法融合

    • 同时实现模板匹配和神经网络两种方法
    • 提供性能对比分析
  2. 特征可视化

    • 展示卷积层提取的特征图
    • 帮助理解神经网络工作原理
  3. 交互式界面

    • 实时手写识别演示
    • 显示预处理过程和识别结果
  4. 全面的评估

    • 混淆矩阵分析
    • 错误分类可视化
    • 准确率对比

系统扩展建议

  1. 数据增强

    % 添加旋转、平移、缩放等变换
    augmenter = imageDataAugmenter(...
        'RandRotation', [-15, 15], ...
        'RandXTranslation', [-3, 3], ...
        'RandYTranslation', [-3, 3]);
    
  2. 迁移学习

    % 使用预训练的ResNet或MobileNet
    net = resnet50;
    lgraph = layerGraph(net);
    
  3. 模型优化

    • 添加注意力机制
    • 尝试不同的网络架构
    • 使用贝叶斯优化调整超参数
  4. 实时视频识别

    • 集成摄像头输入
    • 实现实时手写数字识别
  5. 移动端部署

    • 使用MATLAB Coder生成C++代码
    • 部署到移动设备或嵌入式系统

这个系统全面展示了手写数字识别的关键技术和实现方法,通过交互式界面增强了用户体验,适用于教育演示和实际应用开发。

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐