PyTorch使用Dataset和DataLoader加载数据集
在PyTorch里优化器都是小批量小批量地优化训练的,即每次都会从原数据集中取出一小批量进行训练,完成一次权重更新后,再从原数据集中取下一个小批量数据,然后再训练再更新。。。比如最常用的小批量随机梯度下降(Mini-Batch Gradient Descent,MBGD)。毕竟原数据集往往很大,不可能一次性的全部载入内存,只能一小批一小批地载入内存。训练完了就扔了,再加载下一小批。如何实现批量地加
·
在深度学习中训练模型都是小批量小批量地优化训练的,即每次都会从原数据集中取出一小批量进行训练,完成一次权重更新后,再从原数据集中取下一个小批量数据,然后再训练再更新。比如最常用的小批量随机梯度下降(Mini-Batch Gradient Descent,MBGD)。
毕竟原数据集往往很大,不可能一次性的全部载入模型,只能一小批一小批地载入。训练完了就扔了,再加载下一小批。
在PyTorch的torch.utils.data
包中定义了两个类Dataset
和DataLoader
,这两个类就是用来批量地加载数据的。
下面说一下其用法:
1、准备数据
在使用Dataset和DataLoader之前需要先准备好数据,这里随即构造了一段数据:
# 自己编造一个数据集
import pandas as pd
import numpy as np
data = np.random.rand(128,3)
data = pd.DataFrame(data, columns=['feature_1', 'feature_2', 'label'])
数据形式如下:
或者,如果你有数据,可以这样读取:
data = pd.read_csv('data/diabetes.csv') # 'data/diabetes.csv' 是我的数据的路径
数据形式如下:
2、写一个简单的数据加载器
import numpy as np
import pandas as pd
import torch
# utils是工具包
from torch.utils.data import Dataset # Dataset是个抽象类,只能用于继承
from torch.utils.data import DataLoader # DataLoader需实例化,用于加载数据
class MyDataset(Dataset): # 继承Dataset类
def __init__(self, df):
# 把数据和标签拿出来
self.x_data = df[['feature_1', 'feature_2']].values
self.y_data = df[['label']].values
# 数据集的长度
self.length = len(self.y_data)
# 下面两个魔术方法比较好写,直接照着这个格式写就行了
def __getitem__(self, index): # 参数index必写
return self.x_data[index], self.y_data[index]
def __len__(self):
return self.length # 只需返回数据集的长度即可
# 实例化
my_dataset = MyDataset(data)
train_loader = DataLoader(dataset=my_dataset, # 要传递的数据集
batch_size=32, #一个小批量数据的大小是多少
shuffle=True, # 数据集顺序是否要打乱,一般是要的。测试数据集一般没必要
num_workers=0) # 需要几个进程来一次性读取这个小批量数据,默认0,一般用0就够了,多了有时会出一些底层错误。
3、如何使用数据加载器来训练模型?
类似于迭代器的使用
for epoch in range(100):
# ---------------主要看这两行代码------------------
for i, data in enumerate(train_loader):
# 1. 数据准备
inputs, labels = data
# ---------------主要看这两行代码------------------
# 2. 前向传播
y_pred = model(inputs)
loss = criterion(y_pred, labels)
# 3. 反向传播
loss.backward()
# 4. 权重/模型更新
optimizer.step()
# 5. 梯度清零
optimizer.zero_grad()
总结模板
模板如下:
class MyDataset(Dataset):
def __init__(self):
'''
有两种写法:
1、将全部数据都加载进内存里,适用于少量数据(上面那个例子就是全部加载);
2、当数据量或者标签量很大时,比如图片,就把这些数据或者标签放到文件或数据库里去,只需在此方法中初始化定义这些文件索引的列表即可。
'''
pass
# 以下2个方法都是魔法方法
def __getitem__(self, index): # 表示将来实例化这个对象后,它能支持下标(索引)操作,也就是能通过索引把里面的数据拿出来。
pass
def __len__(self): # 返回数据集的长度
pass
my_dataset = MyDataset()
train_loader = DataLoader(dataset=my_dataset, # 传递数据集
batch_size=32, #一个小批量容量是多少
shuffle=True, # 数据集顺序是否要打乱,一般是要的。测试数据集一般没必要
num_workers=0) # 需要几个进程来一次性读取这个小批量数据
更多推荐
所有评论(0)