DJL——java深度学习框架学习笔记——自定义数据集
自定义数据集:DJL 作为 Java 深度学习框架,进行深度学习时首先要准备好数据集——只有有了数据,才能训练自己的模型。本文以人脸检测为例,使用的数据集是 CelebA(点击下载),并列出所需的 DJL 基础依赖。
·
自定义数据集
使用 DJL 这一 Java 深度学习框架进行深度学习时,首先最重要的是数据集:只有有了数据,才能训练自己的模型。
本文以人脸检测为例,使用的数据集是 CelebA(点击下载)。
DJL基础依赖
<!-- Place these entries inside the <dependencies> element of pom.xml. -->
<!-- DJL core API (ai.djl.* classes used by FaceDataSet). -->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>0.9.0</version>
</dependency>
<!-- DJL basic datasets (provides ai.djl.basicdataset.BasicDatasets used by the builder). -->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>basicdataset</artifactId>
<version>0.9.0</version>
</dependency>
<!-- DJL PyTorch model zoo. -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-model-zoo</artifactId>
<version>0.9.0</version>
</dependency>
<!-- DJL PyTorch engine binding; keep its version aligned with the other DJL artifacts. -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>0.9.0</version>
</dependency>
<!-- Native PyTorch library.
     NOTE(review): confirm that native PyTorch 1.7.1 is the version supported by
     pytorch-engine 0.9.0 (the DJL 0.9.0 docs may list 1.7.0); a mismatch typically
     fails when the native library is loaded. -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-auto</artifactId>
<version>1.7.1</version>
</dependency>
以上是 DJL 的基础依赖,有了它们才能进行深度学习。DJL 由 AWS 开发和维护,底层通过调用用 C/C++ 实现的原生深度学习引擎(这里使用的是 PyTorch)来完成计算。既然有 AWS 维护,直接拿来用,它不香吗?
数据集创建
这是官方的介绍
DJL中的数据集代表原始数据和加载过程。RandomAccessDataset实现了Dataset接口,并提供了全面的数据加载功能。RandomAccessDataset还是支持使用索引对数据进行随机访问的基本数据集。您可以通过扩展RandomAccessDataset轻松自定义自己的数据集
官网地址
我这里介绍的主要是创建自定义数据集
创建自定义数据集时,需要继承 RandomAccessDataset 类,并重写它的一些方法(下面的代码重写了 get、availableSize 和 prepare)。
废话不多说,直接上代码
package com.face.demo.utlis;
import ai.djl.Application;
import ai.djl.basicdataset.BasicDatasets;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.Resource;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Transform;
import ai.djl.translate.TranslateException;
import ai.djl.util.JsonUtils;
import ai.djl.util.Progress;
import cn.hutool.core.io.file.FileReader;
import com.face.demo.pojo.FaceInfo;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.reflect.TypeToken;
import com.sun.imageio.plugins.common.ImageUtil;
import org.apache.commons.csv.CSVRecord;
import java.io.*;
import java.lang.reflect.Type;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
/**
 * Face-detection dataset built from the CelebA bounding-box annotations.
 *
 * <p>Each record pairs one image (decoded with the builder's {@link Image.Flag}) with a label
 * NDArray of shape {@code (1, 5)} laid out as {@code [classId, cx, cy, w, h]}: {@code classId}
 * is always {@code 0} (the single "face" class) and the box is the center and size of the
 * annotated face, with x/w divided by the image width and y/h by the image height.
 *
 * <p>NOTE(review): the builder's {@code usage} is stored but not used to pick a CelebA split;
 * every annotated image is loaded regardless of TRAIN/TEST. If these labels feed DJL's SSD
 * tooling, confirm whether it expects corner ({@code xmin, ymin, xmax, ymax}) coordinates.
 */
public class FaceDataSet extends RandomAccessDataset {
    private static final String VERSION = "1.0";
    // NOTE(review): unused; looks like a leftover from DJL's banana-detection example.
    private static final String ARTIFACT_ID = "banana";

    /** Default image directory; override with {@link Builder#optImageDir(Path)}. */
    private static final String DEFAULT_IMAGE_DIR =
            "C:\\Users\\mzp\\Documents\\img_celeba.7z\\img_celeba\\img_celeba";

    /** Default annotation file; override with {@link Builder#optAnnotationFile(Path)}. */
    private static final String DEFAULT_ANNOTATION_FILE =
            "C:\\Users\\mzp\\Documents\\Anno\\list_bbox_celeba.txt";

    private final Usage usage;
    private final Image.Flag flag;
    private final Path imageDir;
    private final Path annotationFile;
    private final List<Path> imagePaths;
    private final List<float[]> labels;
    private final Resource resource;
    private boolean prepared;

    /**
     * Creates the dataset from a configured builder. No files are read until {@link
     * #prepare(Progress)} is called.
     *
     * @param builder the configured builder
     */
    public FaceDataSet(FaceDataSet.Builder builder) {
        super(builder);
        this.usage = builder.usage;
        this.flag = builder.flag;
        this.imageDir = builder.imageDir;
        this.annotationFile = builder.annotationFile;
        this.imagePaths = new ArrayList<>();
        this.labels = new ArrayList<>();
        MRL mrl = MRL.dataset(Application.CV.ANY, builder.groupId, builder.artifactId);
        this.resource = new Resource(builder.repository, mrl, VERSION);
    }

    /**
     * Returns a new builder with the default configuration.
     *
     * @return a new {@link Builder}
     */
    public static FaceDataSet.Builder builder() {
        return new FaceDataSet.Builder();
    }

    /**
     * Loads the image at {@code index} and its label.
     *
     * @param manager the manager that owns the created NDArrays
     * @param index the record index
     * @return a record whose data is the image and whose label has shape {@code (1, 5)}
     * @throws IOException if the image cannot be read
     */
    @Override
    public Record get(NDManager manager, long index) throws IOException {
        int idx = Math.toIntExact(index);
        NDList data =
                new NDList(
                        ImageFactory.getInstance()
                                .fromFile(imagePaths.get(idx))
                                .toNDArray(manager, flag));
        NDArray label = manager.create(labels.get(idx));
        // Prepend a leading axis: one box per image, so the label becomes (1, 5).
        NDList target = new NDList(label.reshape(new Shape(1).addAll(label.getShape())));
        return new Record(data, target);
    }

    /**
     * Returns the number of records loaded by {@link #prepare(Progress)}.
     *
     * @return the number of prepared records (0 before preparation)
     */
    @Override
    protected long availableSize() {
        return imagePaths.size();
    }

    /**
     * Reads the annotation file and records, for every face, its image path and label.
     *
     * <p>Runs once; later calls are no-ops. Results are published only after the whole file
     * has been read, so a failure part-way through leaves the dataset empty and retryable.
     *
     * @param progress not used
     * @throws IOException if the annotation file or any referenced image cannot be read
     */
    @Override
    public void prepare(Progress progress) throws IOException, TranslateException {
        if (prepared) {
            return;
        }
        List<Path> newPaths = new ArrayList<>();
        List<float[]> newLabels = new ArrayList<>();
        try (BufferedReader reader = Files.newBufferedReader(annotationFile)) {
            // list_bbox_celeba.txt starts with the record count and then the column header.
            reader.readLine();
            reader.readLine();
            String line;
            while ((line = reader.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue; // tolerate blank (e.g. trailing) lines
                }
                // NOTE(review): assumes FaceInfo maps tokens to image_id, x_1, y_1, width, height.
                FaceInfo faceInfo = new FaceInfo(trimmed.split("\\s+"));
                Path imageFile = imageDir.resolve(faceInfo.getImage_id());
                // Normalize against the same file that get() will serve for this record.
                float[] box = normalizeBox(faceInfo, imageFile);
                float[] labelArray = new float[5];
                labelArray[0] = 0.0f; // single class: face
                System.arraycopy(box, 0, labelArray, 1, box.length);
                newPaths.add(imageFile);
                newLabels.add(labelArray);
            }
        }
        imagePaths.addAll(newPaths);
        labels.addAll(newLabels);
        prepared = true;
    }

    /** Fluent builder for {@link FaceDataSet}. */
    public static final class Builder extends BaseBuilder<FaceDataSet.Builder> {
        Repository repository;
        String groupId;
        String artifactId;
        Usage usage;
        Image.Flag flag;
        Path imageDir;
        Path annotationFile;

        Builder() {
            this.repository = BasicDatasets.REPOSITORY;
            this.groupId = "ai.djl.basicdataset";
            this.artifactId = "face";
            this.usage = Usage.TRAIN;
            this.flag = Image.Flag.COLOR;
            this.imageDir = Paths.get(DEFAULT_IMAGE_DIR);
            this.annotationFile = Paths.get(DEFAULT_ANNOTATION_FILE);
        }

        @Override
        public FaceDataSet.Builder self() {
            return this;
        }

        /**
         * Sets the dataset usage.
         *
         * @param usage the usage
         * @return this builder
         */
        public FaceDataSet.Builder optUsage(Usage usage) {
            this.usage = usage;
            return self();
        }

        /**
         * Sets the repository used to build the dataset's resource descriptor.
         *
         * @param repository the repository
         * @return this builder
         */
        public FaceDataSet.Builder optRepository(Repository repository) {
            this.repository = repository;
            return self();
        }

        /**
         * Sets the group id of the dataset's resource descriptor.
         *
         * @param groupId the group id
         * @return this builder
         */
        public FaceDataSet.Builder optGroupId(String groupId) {
            this.groupId = groupId;
            return this;
        }

        /**
         * Sets the artifact id; a {@code "group:artifact"} value also sets the group id.
         *
         * @param artifactId the artifact id, optionally prefixed with {@code group:}
         * @return this builder
         */
        public FaceDataSet.Builder optArtifactId(String artifactId) {
            if (artifactId.contains(":")) {
                String[] tokens = artifactId.split(":");
                this.groupId = tokens[0];
                this.artifactId = tokens[1];
            } else {
                this.artifactId = artifactId;
            }
            return this;
        }

        /**
         * Sets how images are decoded (e.g. color or grayscale).
         *
         * @param flag the image flag
         * @return this builder
         */
        public FaceDataSet.Builder optFlag(Image.Flag flag) {
            this.flag = flag;
            return self();
        }

        /**
         * Sets the directory that contains the CelebA images.
         *
         * @param imageDir the image directory
         * @return this builder
         */
        public FaceDataSet.Builder optImageDir(Path imageDir) {
            if (imageDir == null) {
                throw new NullPointerException("imageDir");
            }
            this.imageDir = imageDir;
            return self();
        }

        /**
         * Sets the CelebA bounding-box annotation file ({@code list_bbox_celeba.txt}).
         *
         * @param annotationFile the annotation file
         * @return this builder
         */
        public FaceDataSet.Builder optAnnotationFile(Path annotationFile) {
            if (annotationFile == null) {
                throw new NullPointerException("annotationFile");
            }
            this.annotationFile = annotationFile;
            return self();
        }

        /**
         * Builds the dataset, defaulting the pipeline to {@link ToTensor} when none was set.
         *
         * @return the new dataset
         */
        public FaceDataSet build() {
            if (pipeline == null) {
                pipeline = new Pipeline(new ToTensor());
            }
            return new FaceDataSet(this);
        }
    }

    /**
     * Converts a CelebA box ({@code x_1, y_1, width, height}, in pixels) for the image at
     * {@code faceInfo.getImageURL()} into {@code [cx, cy, w, h]} relative to the image size.
     *
     * @param faceInfo one parsed annotation row
     * @return a 4-element array {@code [cx, cy, w, h]}
     * @throws UncheckedIOException if the image cannot be read
     */
    public float[] Normalized(FaceInfo faceInfo) {
        String imageUrl = faceInfo.getImageURL();
        try {
            return normalizeBox(faceInfo, Paths.get(imageUrl));
        } catch (IOException e) {
            throw new UncheckedIOException("Cannot read image " + imageUrl, e);
        }
    }

    /**
     * Reads the image's dimensions and returns the face box as {@code [cx, cy, w, h]}, with
     * x and w divided by the image width and y and h divided by the image height.
     */
    private static float[] normalizeBox(FaceInfo faceInfo, Path imageFile) throws IOException {
        Image image = ImageFactory.getInstance().fromFile(imageFile);
        float imageWidth = image.getWidth();
        float imageHeight = image.getHeight();
        // CelebA columns: x_1/y_1 are the top-left corner, width/height the box size.
        float left = Float.parseFloat(faceInfo.getX_1());
        float top = Float.parseFloat(faceInfo.getY_1());
        float boxWidth = Float.parseFloat(faceInfo.getWidth());
        float boxHeight = Float.parseFloat(faceInfo.getHeight());
        return new float[] {
            (left + boxWidth / 2.0f) / imageWidth,
            (top + boxHeight / 2.0f) / imageHeight,
            boxWidth / imageWidth,
            boxHeight / imageHeight
        };
    }
}
更多推荐
所有评论(0)