步骤1:准备MNIST数据集进行训练

   /**
         * 步骤1:准备MNIST数据集进行训练
         * batchSize:每个批次有多少元素,batchSize通常是适合内存的2的最大幂。
         * true:随机打乱批次的元素。
         */
        int batchSize = 32;
        Mnist mnist = Mnist.builder().setSampling(batchSize, true).build();
        try {
            mnist.prepare(new ProgressBar());
        } catch (IOException e) {
            e.printStackTrace();
        }

步骤2:创建模型

/**
         * 步骤2:创建模型
         * 建立一个模型包含要使用的输入、输出、形状和数据类型的附加信息。
         * MNIST数据集中的图像是28x28灰度图像,所以我们将创建一个具有28 x 28输入的MLP块。
         * 每个图像可能有10个类(0到9),输出将是10。
         * 对于隐藏层,选择了新的int[]{128,64}。
         */
        Model model = Model.newInstance("mlp");
        model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));

步骤3:创建一名培训师,设置培训配置

 /**
         * 步骤3:创建一名培训师,设置培训配置
         * 损失函数:损失函数用于衡量我们的模型与数据集的匹配程度。因为函数的值越低越好,所以它被称为“损失”函数。损失是该模型唯一必需的参数
         * 评估器函数:评估器函数还用于衡量我们的模型与数据集的匹配程度。与损失不同,它们只是供人们查看,而不是用于优化模型。由于许多损失并不是那么直观,添加其他评估者(如Accuracy)可以帮助了解模型的运行情况。
         * 培训监听器:培训监听器通过监听器界面为培训过程添加了额外的功能。这可能包括显示训练进度、在训练未定义时提前停止训练或记录表现指标。
         */
        DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())//softmaxCrossEntropyLoss是分类问题的标准损失
                .addEvaluator(new Accuracy()) // 评估函数
                .addTrainingListeners(TrainingListener.Defaults.logging());//培训监听器
        Trainer trainer = model.newTrainer(config);//创建训练器
        /**

步骤5:初始化训练

/**
         * 步骤5:初始化训练
         * 输入形状的第一个轴是批量大小:1。
         * MLP输入形状的第二个轴是输入图像中的像素数28 * 28。
         */
        trainer.initialize(new Shape(1, 28 * 28));

步骤6:训练模型

 /**
         * 步骤6:训练模型
         */
        int epoch = 2;
        try {
            EasyTrain.fit(trainer, epoch, mnist, null);
        } catch (IOException e) {
            e.printStackTrace();
        } catch (TranslateException e) {
            e.printStackTrace();
        }

步骤7:保存模型

/**
         * 步骤7:保存模型
         */
        Path modelDir = Paths.get("D:\\models");
        try {
            Files.createDirectories(modelDir);
            model.setProperty("Epoch", String.valueOf(epoch));
            model.save(modelDir, "mlp");
        } catch (IOException e) {
            e.printStackTrace();
        }
        System.out.println(model.toString());

整合后JAVA文件

package com.lihao;
import java.io.IOException;
import java.nio.file.*;
import ai.djl.*;
import ai.djl.basicdataset.cv.classification.Mnist;
import ai.djl.ndarray.types.*;
import ai.djl.training.*;
import ai.djl.training.dataset.*;
import ai.djl.training.initializer.*;
import ai.djl.training.loss.*;
import ai.djl.training.listener.*;
import ai.djl.training.evaluator.*;
import ai.djl.training.optimizer.*;
import ai.djl.training.util.*;
import ai.djl.basicmodelzoo.cv.classification.*;
import ai.djl.basicmodelzoo.basic.*;
import ai.djl.translate.TranslateException;

/**
 * 训练模型
 * %maven ai.djl:api:0.23.0
 * %maven ai.djl:basicdataset:0.23.0
 * %maven ai.djl:model-zoo:0.23.0
 * %maven ai.djl.mxnet:mxnet-engine:0.23.0
 */
public class DjlTrainModel {
    public static void main(String[] args) {
        /**
         * 步骤1:准备MNIST数据集进行训练
         * batchSize:每个批次有多少元素,batchSize通常是适合内存的2的最大幂。
         * true:随机打乱批次的元素。
         */
        int batchSize = 32;
        Mnist mnist = Mnist.builder().setSampling(batchSize, true).build();
        try {
            mnist.prepare(new ProgressBar());
        } catch (IOException e) {
            e.printStackTrace();
        }
        /**
         * 步骤2:创建模型
         * 建立一个模型包含要使用的输入、输出、形状和数据类型的附加信息。
         * MNIST数据集中的图像是28x28灰度图像,所以我们将创建一个具有28 x 28输入的MLP块。
         * 每个图像可能有10个类(0到9),输出将是10。
         * 对于隐藏层,选择了新的int[]{128,64}。
         */
        Model model = Model.newInstance("mlp");
        model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));
        /**
         * 步骤3:创建一名培训师,设置培训配置
         * 损失函数:损失函数用于衡量我们的模型与数据集的匹配程度。因为函数的值越低越好,所以它被称为“损失”函数。损失是该模型唯一必需的参数
         * 评估器函数:评估器函数还用于衡量我们的模型与数据集的匹配程度。与损失不同,它们只是供人们查看,而不是用于优化模型。由于许多损失并不是那么直观,添加其他评估者(如Accuracy)可以帮助了解模型的运行情况。
         * 培训监听器:培训监听器通过监听器界面为培训过程添加了额外的功能。这可能包括显示训练进度、在训练未定义时提前停止训练或记录表现指标。
         */
        DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())//softmaxCrossEntropyLoss是分类问题的标准损失
                .addEvaluator(new Accuracy()) // 评估函数
                .addTrainingListeners(TrainingListener.Defaults.logging());//培训监听器
        Trainer trainer = model.newTrainer(config);//创建训练器
        /**
         * 步骤5:初始化训练
         * 输入形状的第一个轴是批量大小:1。
         * MLP输入形状的第二个轴是输入图像中的像素数28 * 28。
         */
        trainer.initialize(new Shape(1, 28 * 28));
        /**
         * 步骤6:训练模型
         */
        int epoch = 2;
        try {
            EasyTrain.fit(trainer, epoch, mnist, null);
        } catch (IOException e) {
            e.printStackTrace();
        } catch (TranslateException e) {
            e.printStackTrace();
        }
        /**
         * 步骤7:保存模型
         */
        Path modelDir = Paths.get("D:\\models");
        try {
            Files.createDirectories(modelDir);
            model.setProperty("Epoch", String.valueOf(epoch));
            model.save(modelDir, "mlp");
        } catch (IOException e) {
            e.printStackTrace();
        }
        System.out.println(model.toString());

    }
}

POM文件

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.lihao</groupId>
    <artifactId>djl</artifactId>
    <version>1.0-SNAPSHOT</version>
    <packaging>jar</packaging>

    <name>Spring Boot Blank Project (from https://github.com/making/spring-boot-blank)</name>

    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.7.12</version>
    </parent>

    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <start-class>com.lihao.App</start-class>
        <java.version>1.8</java.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-thymeleaf</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.23.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
            <version>0.23.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
            <version>0.23.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-engine</artifactId>
            <version>0.23.0</version>
        </dependency>
    </dependencies>
    <build>
        <finalName>djl</finalName>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <version>2.6.0</version>
            </plugin>
        </plugins>
    </build>

</project>

运行结果
在这里插入图片描述

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐