AI神经网络手写数字识别系统
snprintf(prediction_text, sizeof(prediction_text), "识别结果: %d", prediction);snprintf(prediction_text, sizeof(prediction_text), "画布已清除");fprintf(stderr, "SDL初始化失败: %s\n", SDL_GetError());fprintf(stderr,
这次为大家带来的是一个使用C语言实现的神经网络手写数字识别系统,结合了简单的图形界面。该程序实现了多层感知机(MLP)神经网络,能够识别用户手写的0-9数字。
一、编译与运行说明:
依赖项
1. SDL2 库 (用于图形界面)
2. C 编译器 (如 GCC)
### 编译命令 (Linux/macOS)
```bash
gcc neural_network.c -o neural_app -lSDL2 -lm
```
### 编译命令 (Windows)
```bash
gcc neural_network.c -o neural_app.exe -I"path_to_SDL2\include" -L"path_to_SDL2\lib" -lSDL2 -lm
```
### 运行程序
```bash
./neural_app
```
功能说明
这个C语言实现的神经网络应用具有以下功能:
1. 神经网络核心:
- 三层感知机结构(输入层784节点,隐藏层128节点,输出层10节点)
- 使用Sigmoid激活函数
- 支持前向传播和反向传播训练
2. 图形界面:
- 560x560像素绘图区域
- 鼠标绘制功能
- 实时结果显示区域
3. 交互功能:
- 鼠标绘制数字
- Enter键进行识别
- C键清除画布
- 实时显示识别结果
4. 训练数据:
- 使用随机生成的模拟数据代替MNIST数据集(示例代码并未加载真实的MNIST文件,实际使用时需替换为真实数据)
- 包含1000个训练样本
- 每个样本为28x28像素的归一化图像
实际应用场景
这个手写数字识别系统可以应用于:
1. 邮政系统:自动识别邮政编码
2. 教育领域:批改数学作业中的手写数字
3. 金融系统:处理支票上的手写金额
4. 工业控制:读取仪表盘上的数字显示
5. 数据录入:将手写表格数字化
优化方向
1. 性能优化:
- 使用更高效的矩阵运算库
- 实现并行计算
- 使用SIMD指令优化
2. 准确率提升:
- 增加网络深度
- 实现卷积神经网络(CNN)
- 使用更高质量的训练数据
3. 功能扩展:
- 支持手写字母识别
- 实现简单数学表达式计算
- 添加模型保存/加载功能
4. 部署优化:
- 移植到嵌入式系统
- 开发移动端版本
- 创建Web服务接口
代码展示:
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <time.h>
#include <SDL2/SDL.h>
#define INPUT_SIZE 784 // 28x28 pixels
#define HIDDEN_SIZE 128 // number of hidden-layer neurons
#define OUTPUT_SIZE 10 // output layer (digits 0-9)
#define LEARNING_RATE 0.1 // step size applied to every weight and bias update in train()
#define EPOCHS 5 // number of full passes over the training samples in main()
#define TRAINING_SAMPLES 1000 // number of samples filled in by load_mnist_data()
// Neural network parameters: one hidden layer between the input and output layers.
typedef struct {
double weights_ih[INPUT_SIZE][HIDDEN_SIZE]; // input -> hidden weights, indexed [input][hidden]
double weights_ho[HIDDEN_SIZE][OUTPUT_SIZE]; // hidden -> output weights, indexed [hidden][output]
double bias_h[HIDDEN_SIZE]; // one bias per hidden neuron
double bias_o[OUTPUT_SIZE]; // one bias per output neuron
} NeuralNetwork;
// 激活函数(sigmoid)
double sigmoid(double x) {
return 1.0 / (1.0 + exp(-x));
}
// Sigmoid derivative written in terms of the sigmoid's OUTPUT: the callers in
// this file pass an already-activated value s, and the result is s * (1 - s).
double sigmoid_derivative(double x) {
    const double complement = 1.0 - x;
    return x * complement;
}
// Returns a pseudo-random value uniformly spread over [-limit, limit].
static double init_uniform(double limit) {
    return (((double)rand() / RAND_MAX) * 2.0 - 1.0) * limit;
}

// Initialises all parameters of the network and seeds rand() from the clock.
//
// Weights use Glorot/Xavier uniform initialisation, i.e. they are drawn from
// [-limit, limit] with limit = sqrt(6 / (fan_in + fan_out)). Drawing from
// [-1, 1] instead, with 784 inputs feeding each hidden neuron, makes the
// pre-activation sums so large that the sigmoid saturates and the gradients
// used by training become vanishingly small.
// Biases start at zero so that they do not shift the activations initially.
//
// net: network whose weights and biases are overwritten.
void init_network(NeuralNetwork* net) {
    srand((unsigned int)time(NULL));

    const double limit_ih = sqrt(6.0 / (INPUT_SIZE + HIDDEN_SIZE));
    const double limit_ho = sqrt(6.0 / (HIDDEN_SIZE + OUTPUT_SIZE));

    // Input -> hidden weights
    for (int i = 0; i < INPUT_SIZE; i++) {
        for (int j = 0; j < HIDDEN_SIZE; j++) {
            net->weights_ih[i][j] = init_uniform(limit_ih);
        }
    }

    // Hidden -> output weights
    for (int i = 0; i < HIDDEN_SIZE; i++) {
        for (int j = 0; j < OUTPUT_SIZE; j++) {
            net->weights_ho[i][j] = init_uniform(limit_ho);
        }
    }

    // Biases
    for (int i = 0; i < HIDDEN_SIZE; i++) {
        net->bias_h[i] = 0.0;
    }
    for (int i = 0; i < OUTPUT_SIZE; i++) {
        net->bias_o[i] = 0.0;
    }
}
// Forward pass: computes the activations of both layers for one input vector.
//
// net:    network parameters (read only here).
// input:  INPUT_SIZE values.
// hidden: receives HIDDEN_SIZE sigmoid activations of the hidden layer.
// output: receives OUTPUT_SIZE sigmoid activations of the output layer.
void forward(NeuralNetwork* net, double* input, double* hidden, double* output) {
    // Hidden layer: weighted sum of the inputs, plus bias, through the sigmoid.
    for (int h = 0; h < HIDDEN_SIZE; h++) {
        double acc = 0;
        for (int in = 0; in < INPUT_SIZE; in++) {
            acc += input[in] * net->weights_ih[in][h];
        }
        acc += net->bias_h[h];
        hidden[h] = sigmoid(acc);
    }

    // Output layer: same computation, fed by the hidden activations.
    for (int o = 0; o < OUTPUT_SIZE; o++) {
        double acc = 0;
        for (int h = 0; h < HIDDEN_SIZE; h++) {
            acc += hidden[h] * net->weights_ho[h][o];
        }
        acc += net->bias_o[o];
        output[o] = sigmoid(acc);
    }
}
// Performs one stochastic-gradient-descent step on a single sample, using
// squared-error loss and sigmoid activations in both layers.
//
// net:    network whose weights and biases are updated in place.
// input:  INPUT_SIZE input values.
// target: OUTPUT_SIZE desired output values (main() passes a one-hot vector).
void train(NeuralNetwork* net, double* input, double* target) {
    double hidden[HIDDEN_SIZE] = {0};
    double output[OUTPUT_SIZE] = {0};

    // Forward pass
    forward(net, input, hidden, output);

    // Output-layer delta: error scaled by the sigmoid slope at each output.
    // The same delta must drive BOTH the output weight update and the error
    // propagated back to the hidden layer; propagating the raw error instead
    // makes the hidden-layer gradient inconsistent with the output-layer one.
    double output_delta[OUTPUT_SIZE];
    for (int j = 0; j < OUTPUT_SIZE; j++) {
        output_delta[j] = (target[j] - output[j]) * sigmoid_derivative(output[j]);
    }

    // Hidden-layer delta, computed with the weights as they were during the
    // forward pass (i.e. before weights_ho is modified below).
    double hidden_delta[HIDDEN_SIZE];
    for (int i = 0; i < HIDDEN_SIZE; i++) {
        double back_error = 0;
        for (int j = 0; j < OUTPUT_SIZE; j++) {
            back_error += output_delta[j] * net->weights_ho[i][j];
        }
        hidden_delta[i] = back_error * sigmoid_derivative(hidden[i]);
    }

    // Update hidden -> output weights
    for (int i = 0; i < HIDDEN_SIZE; i++) {
        for (int j = 0; j < OUTPUT_SIZE; j++) {
            net->weights_ho[i][j] += LEARNING_RATE * output_delta[j] * hidden[i];
        }
    }

    // Update input -> hidden weights
    for (int i = 0; i < INPUT_SIZE; i++) {
        for (int j = 0; j < HIDDEN_SIZE; j++) {
            net->weights_ih[i][j] += LEARNING_RATE * hidden_delta[j] * input[i];
        }
    }

    // Update biases
    for (int j = 0; j < OUTPUT_SIZE; j++) {
        net->bias_o[j] += LEARNING_RATE * output_delta[j];
    }
    for (int i = 0; i < HIDDEN_SIZE; i++) {
        net->bias_h[i] += LEARNING_RATE * hidden_delta[i];
    }
}
// Runs a forward pass and returns the index of the strongest output neuron.
// Ties keep the lowest index, because only a strictly greater value replaces
// the current best.
int predict(NeuralNetwork* net, double* input) {
    double hidden[HIDDEN_SIZE] = {0};
    double output[OUTPUT_SIZE] = {0};
    forward(net, input, hidden, output);

    int best = 0;
    for (int candidate = 1; candidate < OUTPUT_SIZE; candidate++) {
        if (output[candidate] > output[best]) {
            best = candidate;
        }
    }
    return best;
}
// Fills the training set with synthetic placeholder data; no MNIST file is read.
// Each sample gets a random label in 0..9 and pixel values in [0, 1): uniform
// noise everywhere, except that pixels in columns 11-17 and rows 6-22 of the
// 28-wide image are redrawn from the brighter range [0.70, 1.00).
void load_mnist_data(double inputs[TRAINING_SAMPLES][INPUT_SIZE],
                     int labels[TRAINING_SAMPLES]) {
    for (int sample = 0; sample < TRAINING_SAMPLES; sample++) {
        labels[sample] = rand() % 10;

        double* pixels = inputs[sample];
        for (int idx = 0; idx < INPUT_SIZE; idx++) {
            const int col = idx % 28;
            const int row = idx / 28;

            // Background noise is always drawn first, so the sequence of
            // rand() calls stays fixed regardless of the pixel's position.
            double value = (rand() % 100) / 100.0;

            const int in_bright_band = (col > 10 && col < 18) && (row > 5 && row < 23);
            if (in_bright_band) {
                value = (rand() % 30 + 70) / 100.0;
            }
            pixels[idx] = value;
        }
    }
}
// Initialises the SDL video subsystem and opens the centred 560x560 window.
// Returns the window, or NULL after logging the SDL error to stderr. When the
// window cannot be created, SDL is shut down again before returning.
SDL_Window* init_sdl() {
    const int init_status = SDL_Init(SDL_INIT_VIDEO);
    if (init_status != 0) {
        fprintf(stderr, "SDL初始化失败: %s\n", SDL_GetError());
        return NULL;
    }

    const char* title = "神经网络手写数字识别 - C语言实现";
    SDL_Window* created = SDL_CreateWindow(title,
                                           SDL_WINDOWPOS_CENTERED,
                                           SDL_WINDOWPOS_CENTERED,
                                           560, 560,
                                           SDL_WINDOW_SHOWN);
    if (created != NULL) {
        return created;
    }

    fprintf(stderr, "窗口创建失败: %s\n", SDL_GetError());
    SDL_Quit();
    return NULL;
}
// 主函数
int main() {
// 初始化神经网络
NeuralNetwork net;
init_network(&net);
// 加载训练数据
double inputs[TRAINING_SAMPLES][INPUT_SIZE];
int labels[TRAINING_SAMPLES];
load_mnist_data(inputs, labels);
printf("开始训练神经网络...\n");
// 训练神经网络
for (int epoch = 0; epoch < EPOCHS; epoch++) {
printf("训练周期 %d/%d\n", epoch + 1, EPOCHS);
for (int i = 0; i < TRAINING_SAMPLES; i++) {
double target[OUTPUT_SIZE] = {0};
target[labels[i]] = 1.0;
train(&net, inputs[i], target);
}
}
printf("训练完成!\n");
// 初始化SDL和窗口
SDL_Window* window = init_sdl();
if (!window) {
return 1;
}
SDL_Renderer* renderer = SDL_CreateRenderer(window, -1,
SDL_RENDERER_ACCELERATED |
SDL_RENDERER_PRESENTVSYNC);
if (!renderer) {
fprintf(stderr, "渲染器创建失败: %s\n", SDL_GetError());
SDL_DestroyWindow(window);
SDL_Quit();
return 1;
}
// 创建绘制表面
SDL_Surface* drawing_surface = SDL_CreateRGBSurface(0, 560, 560, 32,
0x00FF0000, 0x0000FF00,
0x000000FF, 0xFF000000);
if (!drawing_surface) {
fprintf(stderr, "表面创建失败: %s\n", SDL_GetError());
SDL_DestroyRenderer(renderer);
SDL_DestroyWindow(window);
SDL_Quit();
return 1;
}
// 填充黑色背景
SDL_FillRect(drawing_surface, NULL, SDL_MapRGB(drawing_surface->format, 0, 0, 0));
// 创建纹理
SDL_Texture* texture = SDL_CreateTextureFromSurface(renderer, drawing_surface);
if (!texture) {
fprintf(stderr, "纹理创建失败: %s\n", SDL_GetError());
SDL_FreeSurface(drawing_surface);
SDL_DestroyRenderer(renderer);
SDL_DestroyWindow(window);
SDL_Quit();
return 1;
}
// 主循环
int quit = 0;
SDL_Event event;
int drawing = 0;
char prediction_text[50] = "请手写一个数字 (0-9)";
while (!quit) {
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
quit = 1;
break;
case SDL_MOUSEBUTTONDOWN:
if (event.button.button == SDL_BUTTON_LEFT) {
drawing = 1;
}
break;
case SDL_MOUSEBUTTONUP:
if (event.button.button == SDL_BUTTON_LEFT) {
drawing = 0;
}
break;
case SDL_MOUSEMOTION:
if (drawing) {
int x = event.motion.x;
int y = event.motion.y;
// 绘制圆形点
for (int dy = -10; dy <= 10; dy++) {
for (int dx = -10; dx <= 10; dx++) {
int px = x + dx;
int py = y + dy;
if (px >= 0 && px < 560 && py >= 0 && py < 560) {
if (dx*dx + dy*dy <= 100) { // 半径为10的圆
Uint32* pixels = (Uint32*)drawing_surface->pixels;
pixels[py * drawing_surface->w + px] =
SDL_MapRGB(drawing_surface->format, 255, 255, 255);
}
}
}
}
// 更新纹理
SDL_UpdateTexture(texture, NULL, drawing_surface->pixels, drawing_surface->pitch);
}
break;
case SDL_KEYDOWN:
if (event.key.keysym.sym == SDLK_c) {
// 清除画布
SDL_FillRect(drawing_surface, NULL, SDL_MapRGB(drawing_surface->format, 0, 0, 0));
SDL_UpdateTexture(texture, NULL, drawing_surface->pixels, drawing_surface->pitch);
snprintf(prediction_text, sizeof(prediction_text), "画布已清除");
} else if (event.key.keysym.sym == SDLK_RETURN) {
// 识别数字
double input[INPUT_SIZE] = {0};
// 将560x560图像缩小为28x28
for (int y = 0; y < 28; y++) {
for (int x = 0; x < 28; x++) {
double sum = 0;
for (int dy = 0; dy < 20; dy++) {
for (int dx = 0; dx < 20; dx++) {
int px = x * 20 + dx;
int py = y * 20 + dy;
if (px < 560 && py < 560) {
Uint32 pixel = ((Uint32*)drawing_surface->pixels)[py * 560 + px];
Uint8 r, g, b;
SDL_GetRGB(pixel, drawing_surface->format, &r, &g, &b);
if (r > 128) {
sum += 1.0;
}
}
}
}
input[y * 28 + x] = sum / 400.0; // 归一化
}
}
// 进行预测
int prediction = predict(&net, input);
snprintf(prediction_text, sizeof(prediction_text), "识别结果: %d", prediction);
}
break;
}
}
// 渲染
SDL_SetRenderDrawColor(renderer, 0, 0, 0, 255);
SDL_RenderClear(renderer);
// 渲染绘图区域
SDL_Rect drawing_rect = {0, 0, 560, 560};
SDL_RenderCopy(renderer, texture, NULL, &drawing_rect);
// 渲染识别结果
SDL_Color text_color = {255, 255, 255, 255};
SDL_Surface* text_surface = SDL_CreateRGBSurface(0, 300, 40, 32, 0, 0, 0, 0);
SDL_FillRect(text_surface, NULL, SDL_MapRGB(text_surface->format, 50, 50, 150));
// 简单文本渲染(实际应用中应使用SDL_ttf)
SDL_Rect text_rect = {10, 10, 280, 30};
for (int i = 0; prediction_text[i] != '\0'; i++) {
// 简化版字符渲染
SDL_Rect char_rect = {text_rect.x + i * 14, text_rect.y, 12, 20};
SDL_FillRect(text_surface, &char_rect, SDL_MapRGB(text_surface->format, 255, 255, 255));
}
SDL_Texture* text_texture = SDL_CreateTextureFromSurface(renderer, text_surface);
SDL_RenderCopy(renderer, text_texture, NULL, &text_rect);
SDL_FreeSurface(text_surface);
SDL_DestroyTexture(text_texture);
// 渲染帮助文本
SDL_Rect help_rect = {10, 520, 540, 30};
SDL_Surface* help_surface = SDL_CreateRGBSurface(0, 540, 30, 32, 0, 0, 0, 0);
SDL_FillRect(help_surface, NULL, SDL_MapRGB(help_surface->format, 50, 150, 50));
// 简单文本渲染
const char* help_text = "按Enter识别 | 按C清除 | 鼠标绘制";
for (int i = 0; help_text[i] != '\0'; i++) {
SDL_Rect char_rect = {10 + i * 14, 10, 12, 20};
SDL_FillRect(help_surface, &char_rect, SDL_MapRGB(help_surface->format, 255, 255, 255));
}
text_texture = SDL_CreateTextureFromSurface(renderer, help_surface);
SDL_RenderCopy(renderer, text_texture, NULL, &help_rect);
SDL_FreeSurface(help_surface);
SDL_DestroyTexture(text_texture);
SDL_RenderPresent(renderer);
}
// 清理资源
SDL_DestroyTexture(texture);
SDL_FreeSurface(drawing_surface);
SDL_DestroyRenderer(renderer);
SDL_DestroyWindow(window);
SDL_Quit();
return 0;
}
总结:本文用C语言实现的手写数字识别程序展示了神经网络的基本原理和实际应用;虽然相比Python实现更底层,但在资源受限的环境中具有明显优势。
更多推荐
所有评论(0)