
Deep Java Library(一)模型训练
Deep Java Library深度学习
·
步骤1:准备MNIST数据集进行训练
/**
* 步骤1:准备MNIST数据集进行训练
* batchSize:每个批次有多少元素,batchSize通常是适合内存的2的最大幂。
* true:随机打乱批次的元素。
*/
int batchSize = 32;
Mnist mnist = Mnist.builder().setSampling(batchSize, true).build();
try {
mnist.prepare(new ProgressBar());
} catch (IOException e) {
e.printStackTrace();
}
步骤2:创建模型
/**
* 步骤2:创建模型
* 建立一个模型包含要使用的输入、输出、形状和数据类型的附加信息。
* MNIST数据集中的图像是28x28灰度图像,所以我们将创建一个具有28 x 28输入的MLP块。
* 每个图像可能有10个类(0到9),输出将是10。
* 对于隐藏层,选择了新的int[]{128,64}。
*/
Model model = Model.newInstance("mlp");
model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));
步骤3:创建一名培训师,设置培训配置
/**
* 步骤3:创建一名培训师,设置培训配置
* 损失函数:损失函数用于衡量我们的模型与数据集的匹配程度。因为函数的值越低越好,所以它被称为“损失”函数。损失是该模型唯一必需的参数
* 评估器函数:评估器函数还用于衡量我们的模型与数据集的匹配程度。与损失不同,它们只是供人们查看,而不是用于优化模型。由于许多损失并不是那么直观,添加其他评估者(如Accuracy)可以帮助了解模型的运行情况。
* 培训监听器:培训监听器通过监听器界面为培训过程添加了额外的功能。这可能包括显示训练进度、在训练未定义时提前停止训练或记录表现指标。
*/
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())//softmaxCrossEntropyLoss是分类问题的标准损失
.addEvaluator(new Accuracy()) // 评估函数
.addTrainingListeners(TrainingListener.Defaults.logging());//培训监听器
Trainer trainer = model.newTrainer(config);//创建训练器
/**
步骤5:初始化训练
/**
* 步骤5:初始化训练
* 输入形状的第一个轴是批量大小:1。
* MLP输入形状的第二个轴是输入图像中的像素数28 * 28。
*/
trainer.initialize(new Shape(1, 28 * 28));
步骤6:训练模型
/**
* 步骤6:训练模型
*/
int epoch = 2;
try {
EasyTrain.fit(trainer, epoch, mnist, null);
} catch (IOException e) {
e.printStackTrace();
} catch (TranslateException e) {
e.printStackTrace();
}
步骤7:保存模型
/**
* 步骤7:保存模型
*/
Path modelDir = Paths.get("D:\\models");
try {
Files.createDirectories(modelDir);
model.setProperty("Epoch", String.valueOf(epoch));
model.save(modelDir, "mlp");
} catch (IOException e) {
e.printStackTrace();
}
System.out.println(model.toString());
整合后JAVA文件
package com.lihao;
import java.io.IOException;
import java.nio.file.*;
import ai.djl.*;
import ai.djl.basicdataset.cv.classification.Mnist;
import ai.djl.ndarray.types.*;
import ai.djl.training.*;
import ai.djl.training.dataset.*;
import ai.djl.training.initializer.*;
import ai.djl.training.loss.*;
import ai.djl.training.listener.*;
import ai.djl.training.evaluator.*;
import ai.djl.training.optimizer.*;
import ai.djl.training.util.*;
import ai.djl.basicmodelzoo.cv.classification.*;
import ai.djl.basicmodelzoo.basic.*;
import ai.djl.translate.TranslateException;
/**
* 训练模型
* %maven ai.djl:api:0.23.0
* %maven ai.djl:basicdataset:0.23.0
* %maven ai.djl:model-zoo:0.23.0
* %maven ai.djl.mxnet:mxnet-engine:0.23.0
*/
public class DjlTrainModel {
public static void main(String[] args) {
/**
* 步骤1:准备MNIST数据集进行训练
* batchSize:每个批次有多少元素,batchSize通常是适合内存的2的最大幂。
* true:随机打乱批次的元素。
*/
int batchSize = 32;
Mnist mnist = Mnist.builder().setSampling(batchSize, true).build();
try {
mnist.prepare(new ProgressBar());
} catch (IOException e) {
e.printStackTrace();
}
/**
* 步骤2:创建模型
* 建立一个模型包含要使用的输入、输出、形状和数据类型的附加信息。
* MNIST数据集中的图像是28x28灰度图像,所以我们将创建一个具有28 x 28输入的MLP块。
* 每个图像可能有10个类(0到9),输出将是10。
* 对于隐藏层,选择了新的int[]{128,64}。
*/
Model model = Model.newInstance("mlp");
model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));
/**
* 步骤3:创建一名培训师,设置培训配置
* 损失函数:损失函数用于衡量我们的模型与数据集的匹配程度。因为函数的值越低越好,所以它被称为“损失”函数。损失是该模型唯一必需的参数
* 评估器函数:评估器函数还用于衡量我们的模型与数据集的匹配程度。与损失不同,它们只是供人们查看,而不是用于优化模型。由于许多损失并不是那么直观,添加其他评估者(如Accuracy)可以帮助了解模型的运行情况。
* 培训监听器:培训监听器通过监听器界面为培训过程添加了额外的功能。这可能包括显示训练进度、在训练未定义时提前停止训练或记录表现指标。
*/
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())//softmaxCrossEntropyLoss是分类问题的标准损失
.addEvaluator(new Accuracy()) // 评估函数
.addTrainingListeners(TrainingListener.Defaults.logging());//培训监听器
Trainer trainer = model.newTrainer(config);//创建训练器
/**
* 步骤5:初始化训练
* 输入形状的第一个轴是批量大小:1。
* MLP输入形状的第二个轴是输入图像中的像素数28 * 28。
*/
trainer.initialize(new Shape(1, 28 * 28));
/**
* 步骤6:训练模型
*/
int epoch = 2;
try {
EasyTrain.fit(trainer, epoch, mnist, null);
} catch (IOException e) {
e.printStackTrace();
} catch (TranslateException e) {
e.printStackTrace();
}
/**
* 步骤7:保存模型
*/
Path modelDir = Paths.get("D:\\models");
try {
Files.createDirectories(modelDir);
model.setProperty("Epoch", String.valueOf(epoch));
model.save(modelDir, "mlp");
} catch (IOException e) {
e.printStackTrace();
}
System.out.println(model.toString());
}
}
POM文件
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.lihao</groupId>
<artifactId>djl</artifactId>
<version>1.0-SNAPSHOT</version>
<packaging>jar</packaging>
<name>Spring Boot Blank Project (from https://github.com/making/spring-boot-blank)</name>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>2.7.12</version>
</parent>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<start-class>com.lihao.App</start-class>
<java.version>1.8</java.version>
</properties>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-thymeleaf</artifactId>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>0.23.0</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>basicdataset</artifactId>
<version>0.23.0</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
<version>0.23.0</version>
</dependency>
<dependency>
<groupId>ai.djl.mxnet</groupId>
<artifactId>mxnet-engine</artifactId>
<version>0.23.0</version>
</dependency>
</dependencies>
<build>
<finalName>djl</finalName>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>1.8</source>
<target>1.8</target>
</configuration>
</plugin>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
<version>2.6.0</version>
</plugin>
</plugins>
</build>
</project>
运行结果
更多推荐
所有评论(0)