首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >一招教你 解决torch.hub.load模型加载问题

一招教你 解决torch.hub.load模型加载问题

作者头像
OpenCV学堂
发布2026-04-02 20:41:57
发布2026-04-02 20:41:57
70
举报

Pytorch Hub模块介绍

PyTorch Hub 支持通过添加一个简单的 hubconf.py 文件,将预训练模型(模型定义和预训练权重)发布到 GitHub 仓库。hubconf.py 可以有多个入口点。每个入口点被定义为一个 Python 函数(例如:您想要发布的预训练模型)。

Resolve the problem stock illustration. Illustration of smart - 27590216
Resolve the problem stock illustration. Illustration of smart - 27590216

预训练权重可以存储在 GitHub 仓库,或通过 torch.hub.load_state_dict_from_url() 加载。如果小于 2GB,建议将其附加到项目仓库中,并使用仓库中的 URL。通过hub方式加载模型代码实现如下:

代码语言:javascript
复制
if pretrained:
  # For checkpoint saved in local GitHub repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
  dirname = os.path.dirname(__file__)
  checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
  state_dict = torch.load(checkpoint)
  model.load_state_dict(state_dict)
  # For checkpoint saved elsewhere
  checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
  model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))

Pytorch Hub使用技巧

PyTorch Hub 提供了方便的 API,通过 torch.hub.list() 查询所有可用的模型;通过 torch.hub.help() 显示文档字符串和示例;使用 torch.hub.load() 加载预训练模型

实际使用过程中,最常用的是load函数,相关函数与参数解释如下:

代码语言:javascript
复制
torch.hub.load(
  repo_or_dir, 
  model, 
  source='github',
  trust_repo=None, 
  force_reload=False,
  verbose=True, 
  skip_validation=False
)

source声明为'github'表示从远程下载代码执行,'local'表示从本地代码库加载,repo_or_dir表示远程加载URL或者本地的文件夹路径;model表示模型名称。模型完整的路径声明在hubconf.py文件中。

整个加载过程,就是先找hubconf.py文件,然后通过

torch.load加载pth文件,然在通过model.load_state_dict函数实现权重参数加载。

注意:加载模型的使用场景很常见,但也可以用于加载其他对象,如 tokenizers、损失函数等。

默认情况下下载以后的模型保存在当前用户的目录下。

代码语言:javascript
复制
~/.cache/torch/hub

本地或者远程加载代码演示

代码语言:javascript
复制
# from a github repo
repo = "pytorch/vision"
model = torch.hub.load(
    repo, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1"
)
# from a local directory
path = "/some/local/path/pytorch/vision"
model = torch.hub.load(path, "resnet50", weights="ResNet50_Weights.DEFAULT")

关键技巧:

经常遇到torch.hub.load加载模型失败的原因是无法访问github导致的,所以推荐先把相关代码库下载到本地,然后从本地加载即可解决。然后还可以修改相关代码让,模型加载也继续本地化,这样就再不用担心网络问题导致模型加载失败。

学习OpenVINO2025

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2025-09-05,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenCV学堂 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档