
Pytorch Hub模块介绍
PyTorch Hub 支持通过添加一个简单的 hubconf.py 文件,将预训练模型(模型定义和预训练权重)发布到 GitHub 仓库。hubconf.py 可以有多个入口点。每个入口点被定义为一个 Python 函数(例如:您想要发布的预训练模型)。

预训练权重可以存储在 GitHub 仓库,或通过 torch.hub.load_state_dict_from_url() 加载。如果小于 2GB,建议将其附加到项目仓库中,并使用仓库中的 URL。通过hub方式加载模型代码实现如下:
if pretrained:
# For checkpoint saved in local GitHub repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
dirname = os.path.dirname(__file__)
checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
state_dict = torch.load(checkpoint)
model.load_state_dict(state_dict)
# For checkpoint saved elsewhere
checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))Pytorch Hub使用技巧
PyTorch Hub 提供了方便的 API,通过 torch.hub.list() 查询所有可用的模型;通过 torch.hub.help() 显示文档字符串和示例;使用 torch.hub.load() 加载预训练模型
实际使用过程中,最常用的是load函数,相关函数与参数解释如下:
torch.hub.load(
repo_or_dir,
model,
source='github',
trust_repo=None,
force_reload=False,
verbose=True,
skip_validation=False
)source声明为'github'表示从远程下载代码执行,'local'表示从本地代码库加载,repo_or_dir表示远程加载URL或者本地的文件夹路径;model表示模型名称。模型完整的路径声明在hubconf.py文件中。
整个加载过程,就是先找hubconf.py文件,然后通过
torch.load加载pth文件,然在通过model.load_state_dict函数实现权重参数加载。
注意:加载模型的使用场景很常见,但也可以用于加载其他对象,如 tokenizers、损失函数等。
默认情况下下载以后的模型保存在当前用户的目录下。
~/.cache/torch/hub本地或者远程加载代码演示
# from a github repo
repo = "pytorch/vision"
model = torch.hub.load(
repo, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1"
)
# from a local directory
path = "/some/local/path/pytorch/vision"
model = torch.hub.load(path, "resnet50", weights="ResNet50_Weights.DEFAULT")关键技巧:
经常遇到torch.hub.load加载模型失败的原因是无法访问github导致的,所以推荐先把相关代码库下载到本地,然后从本地加载即可解决。然后还可以修改相关代码让,模型加载也继续本地化,这样就再不用担心网络问题导致模型加载失败。
学习OpenVINO2025