如何高效保存与加载PyTorch模型:掌握PyTorch Image Models序列化最佳实践

【免费下载链接】pytorch-image-models The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more 【免费下载链接】pytorch-image-models 项目地址: https://gitcode.com/GitHub_Trending/py/pytorch-image-models

PyTorch Image Models(timm)是一个包含大量PyTorch图像编码器和骨干网络的开源项目,提供了ResNet、EfficientNet、Vision Transformer等多种模型的训练、评估、推理和导出脚本,以及预训练权重。本文将详细介绍在timm项目中进行模型序列化(保存与加载)的最佳实践,帮助新手和普通用户轻松掌握模型持久化的核心技巧。

模型序列化的两种核心方法

在PyTorch中,模型序列化主要有两种常用方法:保存/加载模型状态字典(State Dict)和保存/加载整个模型。timm项目中广泛采用状态字典方法,这也是官方推荐的最佳实践。

方法一:保存与加载模型状态字典(推荐)

状态字典仅包含模型的参数权重,不包含模型结构信息,因此需要先创建模型实例再加载权重。这种方法更加灵活,支持跨设备和版本迁移。

保存模型状态字典

# 保存模型权重
torch.save(model.state_dict(), 'model_weights.pth')

在timm项目中,你可以在convert/convert_from_mxnet.py文件的第66行找到类似实现:torch.save(torch_net.state_dict(), torch_filename)

加载模型状态字典

# 创建模型实例
model = create_model('resnet50', pretrained=False)
# 加载权重
model.load_state_dict(torch.load('model_weights.pth'))
# 设置为评估模式
model.eval()

timm提供了便捷的模型加载辅助函数,如_helpers.py中的_torch_load函数,该函数支持安全加载权重并处理不同设备映射。

方法二:保存与加载整个模型(不推荐)

这种方法会将整个模型对象(包括结构和权重)保存到文件中,虽然使用简单但灵活性较差,可能导致版本兼容性问题。

# 保存整个模型(不推荐)
torch.save(model, 'entire_model.pth')

# 加载整个模型(不推荐)
model = torch.load('entire_model.pth')

timm项目中的模型加载最佳实践

timm项目提供了多种便捷的模型加载方式,让你轻松加载预训练模型或自定义权重。

使用timm的模型工厂函数加载

timm的核心优势之一是提供了统一的模型创建接口,你可以通过create_model函数轻松创建并加载模型:

import timm

# 加载预训练模型
model = timm.create_model('resnet50', pretrained=True)
# 加载自定义权重
model = timm.create_model('resnet50', pretrained=False)
model.load_state_dict(torch.load('my_custom_weights.pth'))

create_model函数定义在timm/models/_factory.py中,支持从Hugging Face Hub或本地路径加载模型配置。

安全加载权重文件

timm提供了安全的权重加载机制,特别是在_helpers.py中实现的_torch_load函数,支持weights_only参数,防止加载恶意代码:

# 安全加载权重文件
checkpoint = _torch_load('model_weights.pth', map_location='cpu', weights_only=True)
model.load_state_dict(checkpoint)

这一安全特性在tests/test_checkpoint_loading.py中有详细的测试用例,确保加载过程的安全性。

模型保存的高级技巧

保存额外训练信息

在实际训练中,除了模型权重外,通常还需要保存优化器状态、学习率调度器、训练轮次等信息:

# 保存训练状态
save_state = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'loss': loss,
}
torch.save(save_state, 'training_checkpoint.pth')

timm的checkpoint_saver.py工具类实现了完整的检查点保存功能,支持自动管理检查点文件。

使用安全张量格式(Safetensors)

timm还支持使用Safetensors格式保存和加载模型权重,这是一种更安全、更快的张量存储格式:

# 保存为Safetensors格式
safetensors.torch.save_file(model.state_dict(), 'model_weights.safetensors')

# 加载Safetensors格式权重
checkpoint = safetensors.torch.load_file('model_weights.safetensors')
model.load_state_dict(checkpoint)

avg_checkpoints.pyclean_checkpoint.py中可以看到timm对Safetensors格式的支持,这也是项目推荐的权重存储格式。

常见问题与解决方案

问题1:加载模型时出现设备不匹配

解决方案:使用map_location参数指定加载设备:

# 加载到CPU
checkpoint = torch.load('model_weights.pth', map_location='cpu')
# 加载到GPU
checkpoint = torch.load('model_weights.pth', map_location='cuda:0')

timm的_torch_load函数已经内置了设备映射功能,你可以在_helpers.py中查看实现细节。

问题2:权重文件与模型结构不匹配

解决方案:使用strict=False参数忽略不匹配的键:

model.load_state_dict(torch.load('model_weights.pth'), strict=False)

或者使用timm提供的load_state_dict辅助函数,该函数支持部分加载和权重转换。

问题3:大型模型加载速度慢

解决方案

  1. 使用Safetensors格式代替传统的PyTorch格式
  2. 分阶段加载权重
  3. 使用torch.loadmmap模式(PyTorch 1.10+)

timm的模型加载流程在_hub.py中进行了优化,支持高效加载大型模型权重。

总结:模型序列化最佳实践清单

为了确保模型序列化的安全性和高效性,建议遵循以下最佳实践:

  1. 优先使用状态字典:总是保存和加载模型的state_dict,而非整个模型对象
  2. 使用安全加载方式:通过weights_only=True参数防止恶意代码执行
  3. 采用Safetensors格式:对于新的权重文件,优先使用Safetensors格式
  4. 保存完整训练状态:除模型权重外,保存优化器、调度器等训练信息
  5. 注意设备兼容性:使用map_location参数确保跨设备加载的兼容性
  6. 利用timm工具函数:使用timm提供的_torch_loadload_state_dict等辅助函数简化加载过程

通过遵循这些最佳实践,你可以在timm项目中高效、安全地进行模型的保存和加载,为模型部署和迁移提供可靠保障。无论是训练新模型、微调预训练模型,还是在生产环境中部署模型,正确的序列化方法都是确保模型性能和稳定性的关键步骤。

【免费下载链接】pytorch-image-models The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more 【免费下载链接】pytorch-image-models 项目地址: https://gitcode.com/GitHub_Trending/py/pytorch-image-models

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐