Pytorch预训练模型下载并加载(以VGG为例)自定义路径

您所在的位置:网站首页 5e怎么自己设置路径 Pytorch预训练模型下载并加载(以VGG为例)自定义路径

Pytorch预训练模型下载并加载(以VGG为例)自定义路径

2024-06-09 23:33| 来源: 网络整理| 查看: 265

简述

一般来讲,Pytorch用torchvision调用vgg之类的模型话,如果电脑在cache(Pytorch硬编码的一个地址)(如果在环境变量中添加了TORCH_HOME 和TORCH_MODEL_ZOO的话,就是在这两个位置的联合的路径下,比如TORCH_MODEL_ZOO\model)否则就是在TORCH_HOME\models或者是~/.torch/models

比如,我的就是C:\Users\lijy2/.torch\models\vgg11-bbd30ac9.pth。

这很有可能并不是我们想要的下载模型放的地址,或者是这样的下载方式很慢等等。

而且这个地址不可以很容易的直接调用,非常不方便。

这点,在我现在用pytorch版本还是github上的最新版本都是没有做类似的改进的。

但是这种设计(可能对我这种强迫症来说),是有需求的。

解决办法

首先,先处理下载的问题。

读了下源码,是使用import torch.utils.model_zoo as model_zoo里面的函数来加载数据。 整理了下源码中涉及的这一部分

from urllib.parse import urlparse import torch.utils.model_zoo as model_zoo import re import os def download_model(url, dst_path): parts = urlparse(url) filename = os.path.basename(parts.path) HASH_REGEX = re.compile(r'-([a-f0-9]*)\.') hash_prefix = HASH_REGEX.search(filename).group(1) model_zoo._download_url_to_file(url, os.path.join(dst_path, filename), hash_prefix, True) return filename

调用实例

model_urls = { 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', } import os path = 'D:/Software/DataSet/models/vgg' if not (os.path.exists(path)): os.makedirs(path) for url in model_urls.values(): download_model(url, path)

输出

100%|███████████████████████████████████████████████████████████████| 531456000/531456000 [01:14


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3