【Torchvision】torchvision.datasets 模块:提供标准化的计算机视觉数据集
torchvision.datasets 是 PyTorch 中处理计算机视觉数据的核心模块,提供多种预定义数据集(如 MNIST、CIFAR、ImageNet 等),支持图像分类、目标检测和分割任务。该模块支持自动下载、数据变换,并能与 DataLoader 无缝配合,实现批处理和多线程加载。用户也可通过继承 Dataset 类创建自定义数据集。标准数据集通常包含 root、train、down
torchvision.datasets 是 torchvision 库中的一个核心模块,专门用于提供标准化的计算机视觉数据集。这些数据集可以直接用于深度学习任务,如图像分类、目标检测、图像分割等。torchvision.datasets 中的每个数据集类都是 PyTorch 的 torch.utils.data.Dataset 类的子类,兼容 PyTorch 的 DataLoader,支持灵活的数据加载、预处理和数据增强。
以下是对 torchvision.datasets 模块的详尽介绍,包括其功能、常用数据集、用法、参数、自定义数据集支持,以及代码示例和注意事项。
1. torchvision.datasets 模块概述
torchvision.datasets 提供了一系列预定义的计算机视觉数据集,涵盖了图像分类、目标检测、图像分割、视频分类等任务。这些数据集通常包含图像数据和对应的标签或注释,适合用于模型训练、验证和测试。
主要特点:
- 标准化接口:所有数据集类都实现了
__getitem__和__len__方法,符合 PyTorch 的Dataset接口,可以无缝与DataLoader结合。 - 内置下载功能:大部分数据集支持自动下载(通过设置
download=True)。 - 灵活的变换支持:通过
transform和target_transform参数,可以对图像和标签应用预处理或数据增强。 - 多样化的数据集:涵盖小型数据集(如 MNIST、CIFAR-10)到大规模数据集(如 ImageNet、COCO)。
2. 常用数据集
torchvision.datasets 提供了多种数据集,适用于不同的计算机视觉任务。以下是常用的数据集及其特点:
(1) 图像分类数据集
这些数据集主要用于图像分类任务,包含图像和对应的类别标签。
-
MNIST:
- 描述:手写数字数据集,包含 60,000 张训练图像和 10,000 张测试图像,每张图像为 28x28 像素的灰度图,10 个类别(0-9)。
- 用途:适合初学者,用于简单的分类任务。
- 加载方式:
from torchvision.datasets import MNIST import torchvision.transforms as transforms dataset = MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
-
CIFAR-10 / CIFAR-100:
- 描述:
- CIFAR-10:包含 60,000 张 32x32 彩色图像,10 个类别(如飞机、汽车、鸟等)。
- CIFAR-100:类似 CIFAR-10,但有 100 个细粒度类别。
- 用途:适合中小规模的分类任务,常用于测试卷积神经网络。
- 加载方式:
from torchvision.datasets import CIFAR10 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
- 描述:
-
ImageNet:
- 描述:大规模图像分类数据集,包含 1,000 个类别,约 120 万训练图像和 5 万验证图像。
- 用途:常用于预训练模型(如 ResNet、VGG)的训练和迁移学习。
- 注意:ImageNet 数据集较大,需手动下载并解压到指定目录,torchvision 不提供自动下载。
- 加载方式:
from torchvision.datasets import ImageNet dataset = ImageNet(root='./data', split='train', transform=transforms.ToTensor())
-
Fashion-MNIST:
- 描述:类似 MNIST,包含 70,000 张 28x28 灰度图像,10 个时尚产品类别(如 T 恤、裤子等)。
- 用途:比 MNIST 更具挑战性,适合分类任务实验。
- 加载方式:与 MNIST 类似。
(2) 目标检测与分割数据集
这些数据集提供图像及其对应的边界框或像素级分割标签,适用于目标检测和图像分割任务。
-
COCO (Common Objects in Context):
- 描述:包含 80 个对象类别,支持目标检测、实例分割、关键点检测和全景分割。提供训练集(约 118,000 张图像)和验证集。
- 用途:广泛用于目标检测和分割模型(如 Faster R-CNN、Mask R-CNN)的训练。
- 加载方式:
from torchvision.datasets import CocoDetection dataset = CocoDetection(root='./data/coco/images', annFile='./data/coco/annotations/instances_train2017.json', transform=transforms.ToTensor())
-
Pascal VOC:
- 描述:包含 20 个对象类别,支持目标检测和图像分割。提供 2007 和 2012 年版本。
- 用途:适合中小规模的目标检测和分割任务。
- 加载方式:
from torchvision.datasets import VOCDetection dataset = VOCDetection(root='./data', year='2012', image_set='train', download=True, transform=transforms.ToTensor())
(3) 其他数据集
- CelebA:人脸数据集,包含 20 万张名人图像,带属性标签(如性别、眼镜等),适合人脸识别或属性分类。
- Cityscapes:城市街景数据集,支持语义分割和实例分割。
- SVHN (Street View House Numbers):街景门牌号数据集,类似 MNIST,但包含彩色图像。
- LSUN:大规模场景理解数据集,支持场景分类和生成任务。
- Kinetics-400 / Kinetics-600:视频数据集,用于视频分类任务。
3. 数据集类的通用参数
大多数 torchvision.datasets 的数据集类都支持以下通用参数:
root(str):数据存储的根目录。train(bool):是否加载训练集(True)或测试/验证集(False)。某些数据集使用split参数(如 ‘train’、‘val’)。download(bool):如果数据集不存在,是否自动下载(注意:ImageNet 和 COCO 不支持自动下载)。transform(callable):应用于图像的变换(如transforms.ToTensor())。target_transform(callable):应用于标签的变换(例如,调整标签格式)。image_set(str):某些数据集(如 VOC)使用此参数指定子集(‘train’、‘val’ 等)。
示例:
dataset = torchvision.datasets.MNIST(
root='./data',
train=True,
download=True,
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
)
4. 结合 DataLoader 使用
torchvision.datasets 数据集通常与 PyTorch 的 torch.utils.data.DataLoader 结合使用,以实现批量加载、数据打乱和多线程加载。
-
示例:
from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10 import torchvision.transforms as transforms # 定义变换 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) # 加载数据集 dataset = CIFAR10(root='./data', train=True, download=True, transform=transform) # 创建 DataLoader dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2) # 遍历数据 for images, labels in dataloader: print(images.shape) # 形状为 [batch_size, C, H, W] print(labels.shape) # 形状为 [batch_size] break -
DataLoader 参数:
batch_size:每批次样本数量。shuffle:是否打乱数据。num_workers:加载数据的线程数(建议根据 CPU 核心数设置,Windows 用户可能需设置num_workers=0避免问题)。
5. 自定义数据集
如果标准数据集无法满足需求,torchvision.datasets 提供了工具来加载自定义数据集,例如:
(1) ImageFolder
ImageFolder 是一个便捷的类,用于加载按文件夹组织的图像数据集。假设你的数据集目录结构如下:
root/class1/
root/class2/
每个文件夹对应一个类别,文件夹内包含图像文件。
-
用法:
from torchvision.datasets import ImageFolder import torchvision.transforms as transforms transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor() ]) dataset = ImageFolder(root='./custom_data', transform=transform) dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True) -
特点:
- 自动根据文件夹名称分配类别标签。
- 支持
transform和target_transform。 - 适合简单的分类任务。
(2) 自定义 Dataset 类
对于更复杂的场景,可以继承 torch.utils.data.Dataset 创建自定义数据集。
- 示例:
from torch.utils.data import Dataset from PIL import Image import os class CustomDataset(Dataset): def __init__(self, root_dir, transform=None): self.root_dir = root_dir self.transform = transform self.images = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith('.jpg')] self.labels = [int(f.split('_')[0]) for f in os.listdir(root_dir) if f.endswith('.jpg')] def __len__(self): return len(self.images) def __getitem__(self, idx): img_path = self.images[idx] image = Image.open(img_path).convert('RGB') label = self.labels[idx] if self.transform: image = self.transform(image) return image, label # 使用自定义数据集 dataset = CustomDataset(root_dir='./custom_data', transform=transforms.ToTensor()) dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
6. 高级功能与注意事项
(1) 数据增强
结合 torchvision.transforms,可以为数据集添加数据增强操作,增强模型泛化能力。例如:
transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(10),
transforms.ColorJitter(brightness=0.2, contrast=0.2),
transforms.ToTensor()
])
dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
(2) 内存与性能优化
- 大数据集:对于 ImageNet 或 COCO 等大型数据集,建议使用 SSD 存储以提高 I/O 速度。
- num_workers:合理设置
DataLoader的num_workers,避免过多线程导致内存溢出。 - 预加载:对于小型数据集,可以将整个数据集加载到内存中以加速训练(需要足够内存)。
(3) 常见问题
- 数据集下载失败:检查网络连接,或手动下载数据集并放置在
root目录。 - 版本兼容性:确保 torchvision 和 PyTorch 版本匹配,建议使用最新版本(截至 2025 年 10 月,最新版本为 torchvision 0.18.0)。
- COCO 数据格式:COCO 数据集需要单独下载图像和注释文件(JSON 格式),并正确指定路径。
7. 学习建议
- 初学者:从 MNIST 或 CIFAR-10 开始,学习如何加载数据集、应用 transforms 和使用 DataLoader。
- 进阶:尝试 COCO 或 VOC 数据集,结合预训练模型(如 Faster R-CNN)进行目标检测。
- 高级:实现自定义数据集,处理复杂的数据格式(如多模态数据或视频)。
- 实践:结合 torchvision.models 和 transforms,构建一个完整的图像分类或检测 pipeline。
8. 总结
torchvision.datasets 模块是计算机视觉任务的强大工具,提供从简单到复杂的数据集支持,涵盖分类、检测、分割等任务。通过与 transforms 和 DataLoader 的结合,可以高效地加载和预处理数据。无论是使用标准数据集还是自定义数据集,torchvision.datasets 都提供了灵活的接口和丰富的功能。
更多推荐
所有评论(0)