torchvision.datasetstorchvision 库中的一个核心模块,专门用于提供标准化的计算机视觉数据集。这些数据集可以直接用于深度学习任务,如图像分类、目标检测、图像分割等。torchvision.datasets 中的每个数据集类都是 PyTorch 的 torch.utils.data.Dataset 类的子类,兼容 PyTorch 的 DataLoader,支持灵活的数据加载、预处理和数据增强。

以下是对 torchvision.datasets 模块的详尽介绍,包括其功能、常用数据集、用法、参数、自定义数据集支持,以及代码示例和注意事项。


1. torchvision.datasets 模块概述

torchvision.datasets 提供了一系列预定义的计算机视觉数据集,涵盖了图像分类、目标检测、图像分割、视频分类等任务。这些数据集通常包含图像数据和对应的标签或注释,适合用于模型训练、验证和测试。

主要特点
  • 标准化接口:所有数据集类都实现了 __getitem____len__ 方法,符合 PyTorch 的 Dataset 接口,可以无缝与 DataLoader 结合。
  • 内置下载功能:大部分数据集支持自动下载(通过设置 download=True)。
  • 灵活的变换支持:通过 transformtarget_transform 参数,可以对图像和标签应用预处理或数据增强。
  • 多样化的数据集:涵盖小型数据集(如 MNIST、CIFAR-10)到大规模数据集(如 ImageNet、COCO)。

2. 常用数据集

torchvision.datasets 提供了多种数据集,适用于不同的计算机视觉任务。以下是常用的数据集及其特点:

(1) 图像分类数据集

这些数据集主要用于图像分类任务,包含图像和对应的类别标签。

  • MNIST

    • 描述:手写数字数据集,包含 60,000 张训练图像和 10,000 张测试图像,每张图像为 28x28 像素的灰度图,10 个类别(0-9)。
    • 用途:适合初学者,用于简单的分类任务。
    • 加载方式
      from torchvision.datasets import MNIST
      import torchvision.transforms as transforms
      
      dataset = MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
      
  • CIFAR-10 / CIFAR-100

    • 描述
      • CIFAR-10:包含 60,000 张 32x32 彩色图像,10 个类别(如飞机、汽车、鸟等)。
      • CIFAR-100:类似 CIFAR-10,但有 100 个细粒度类别。
    • 用途:适合中小规模的分类任务,常用于测试卷积神经网络。
    • 加载方式
      from torchvision.datasets import CIFAR10
      
      transform = transforms.Compose([
          transforms.ToTensor(),
          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
      ])
      dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
      
  • ImageNet

    • 描述:大规模图像分类数据集,包含 1,000 个类别,约 120 万训练图像和 5 万验证图像。
    • 用途:常用于预训练模型(如 ResNet、VGG)的训练和迁移学习。
    • 注意:ImageNet 数据集较大,需手动下载并解压到指定目录,torchvision 不提供自动下载。
    • 加载方式
      from torchvision.datasets import ImageNet
      
      dataset = ImageNet(root='./data', split='train', transform=transforms.ToTensor())
      
  • Fashion-MNIST

    • 描述:类似 MNIST,包含 70,000 张 28x28 灰度图像,10 个时尚产品类别(如 T 恤、裤子等)。
    • 用途:比 MNIST 更具挑战性,适合分类任务实验。
    • 加载方式:与 MNIST 类似。
(2) 目标检测与分割数据集

这些数据集提供图像及其对应的边界框或像素级分割标签,适用于目标检测和图像分割任务。

  • COCO (Common Objects in Context)

    • 描述:包含 80 个对象类别,支持目标检测、实例分割、关键点检测和全景分割。提供训练集(约 118,000 张图像)和验证集。
    • 用途:广泛用于目标检测和分割模型(如 Faster R-CNN、Mask R-CNN)的训练。
    • 加载方式
      from torchvision.datasets import CocoDetection
      
      dataset = CocoDetection(root='./data/coco/images', annFile='./data/coco/annotations/instances_train2017.json',
                              transform=transforms.ToTensor())
      
  • Pascal VOC

    • 描述:包含 20 个对象类别,支持目标检测和图像分割。提供 2007 和 2012 年版本。
    • 用途:适合中小规模的目标检测和分割任务。
    • 加载方式
      from torchvision.datasets import VOCDetection
      
      dataset = VOCDetection(root='./data', year='2012', image_set='train', download=True,
                            transform=transforms.ToTensor())
      
(3) 其他数据集
  • CelebA:人脸数据集,包含 20 万张名人图像,带属性标签(如性别、眼镜等),适合人脸识别或属性分类。
  • Cityscapes:城市街景数据集,支持语义分割和实例分割。
  • SVHN (Street View House Numbers):街景门牌号数据集,类似 MNIST,但包含彩色图像。
  • LSUN:大规模场景理解数据集,支持场景分类和生成任务。
  • Kinetics-400 / Kinetics-600:视频数据集,用于视频分类任务。

3. 数据集类的通用参数

大多数 torchvision.datasets 的数据集类都支持以下通用参数:

  • root (str):数据存储的根目录。
  • train (bool):是否加载训练集(True)或测试/验证集(False)。某些数据集使用 split 参数(如 ‘train’、‘val’)。
  • download (bool):如果数据集不存在,是否自动下载(注意:ImageNet 和 COCO 不支持自动下载)。
  • transform (callable):应用于图像的变换(如 transforms.ToTensor())。
  • target_transform (callable):应用于标签的变换(例如,调整标签格式)。
  • image_set (str):某些数据集(如 VOC)使用此参数指定子集(‘train’、‘val’ 等)。

示例:

dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
)

4. 结合 DataLoader 使用

torchvision.datasets 数据集通常与 PyTorch 的 torch.utils.data.DataLoader 结合使用,以实现批量加载、数据打乱和多线程加载。

  • 示例

    from torch.utils.data import DataLoader
    from torchvision.datasets import CIFAR10
    import torchvision.transforms as transforms
    
    # 定义变换
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    # 加载数据集
    dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
    
    # 创建 DataLoader
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2)
    
    # 遍历数据
    for images, labels in dataloader:
        print(images.shape)  # 形状为 [batch_size, C, H, W]
        print(labels.shape)  # 形状为 [batch_size]
        break
    
  • DataLoader 参数

    • batch_size:每批次样本数量。
    • shuffle:是否打乱数据。
    • num_workers:加载数据的线程数(建议根据 CPU 核心数设置,Windows 用户可能需设置 num_workers=0 避免问题)。

5. 自定义数据集

如果标准数据集无法满足需求,torchvision.datasets 提供了工具来加载自定义数据集,例如:

(1) ImageFolder

ImageFolder 是一个便捷的类,用于加载按文件夹组织的图像数据集。假设你的数据集目录结构如下:

root/class1/
root/class2/

每个文件夹对应一个类别,文件夹内包含图像文件。

  • 用法

    from torchvision.datasets import ImageFolder
    import torchvision.transforms as transforms
    
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    
    dataset = ImageFolder(root='./custom_data', transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    
  • 特点

    • 自动根据文件夹名称分配类别标签。
    • 支持 transformtarget_transform
    • 适合简单的分类任务。
(2) 自定义 Dataset 类

对于更复杂的场景,可以继承 torch.utils.data.Dataset 创建自定义数据集。

  • 示例
    from torch.utils.data import Dataset
    from PIL import Image
    import os
    
    class CustomDataset(Dataset):
        def __init__(self, root_dir, transform=None):
            self.root_dir = root_dir
            self.transform = transform
            self.images = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith('.jpg')]
            self.labels = [int(f.split('_')[0]) for f in os.listdir(root_dir) if f.endswith('.jpg')]
    
        def __len__(self):
            return len(self.images)
    
        def __getitem__(self, idx):
            img_path = self.images[idx]
            image = Image.open(img_path).convert('RGB')
            label = self.labels[idx]
    
            if self.transform:
                image = self.transform(image)
    
            return image, label
    
    # 使用自定义数据集
    dataset = CustomDataset(root_dir='./custom_data', transform=transforms.ToTensor())
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    

6. 高级功能与注意事项

(1) 数据增强

结合 torchvision.transforms,可以为数据集添加数据增强操作,增强模型泛化能力。例如:

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor()
])
dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
(2) 内存与性能优化
  • 大数据集:对于 ImageNet 或 COCO 等大型数据集,建议使用 SSD 存储以提高 I/O 速度。
  • num_workers:合理设置 DataLoadernum_workers,避免过多线程导致内存溢出。
  • 预加载:对于小型数据集,可以将整个数据集加载到内存中以加速训练(需要足够内存)。
(3) 常见问题
  • 数据集下载失败:检查网络连接,或手动下载数据集并放置在 root 目录。
  • 版本兼容性:确保 torchvisionPyTorch 版本匹配,建议使用最新版本(截至 2025 年 10 月,最新版本为 torchvision 0.18.0)。
  • COCO 数据格式:COCO 数据集需要单独下载图像和注释文件(JSON 格式),并正确指定路径。

7. 学习建议

  • 初学者:从 MNIST 或 CIFAR-10 开始,学习如何加载数据集、应用 transforms 和使用 DataLoader。
  • 进阶:尝试 COCO 或 VOC 数据集,结合预训练模型(如 Faster R-CNN)进行目标检测。
  • 高级:实现自定义数据集,处理复杂的数据格式(如多模态数据或视频)。
  • 实践:结合 torchvision.modelstransforms,构建一个完整的图像分类或检测 pipeline。

8. 总结

torchvision.datasets 模块是计算机视觉任务的强大工具,提供从简单到复杂的数据集支持,涵盖分类、检测、分割等任务。通过与 transformsDataLoader 的结合,可以高效地加载和预处理数据。无论是使用标准数据集还是自定义数据集,torchvision.datasets 都提供了灵活的接口和丰富的功能。

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐