PyTorch流行的预训练模型和数据集列表pytorch

您所在的位置:网站首页 pytorch加载预训练模型 PyTorch流行的预训练模型和数据集列表pytorch

PyTorch流行的预训练模型和数据集列表pytorch

#PyTorch流行的预训练模型和数据集列表pytorch| 来源: 网络整理| 查看: 265

pytorch-playground包含基础预训练模型和pytorch中的数据集(MNIST,SVHN,CIFAR10,CIFAR100,STL10,AlexNet,VGG16,VGG19,ResNet,Inception,SqueezeNet)

这是pytorch初学者的游乐场(即资源列表,你可以随意使用如下模型),其中包含流行数据集的预定义模型。目前支持如下模型:

mnist,svhn cifar10,cifar100 stl10 alexnet vgg16,vgg16_bn,vgg19,vgg19_bn resnet18,resnet34,resnet50,resnet101,resnet152 squeezenet_v0,squeezenet_v1 inception_v3

下面是MNIST数据集的例子。下面的代码将自动下载数据集和预先训练的模型。

import torch from torch.autograd import Variable from utee import selector model_raw, ds_fetcher, is_imagenet = selector.select('mnist') ds_val = ds_fetcher(batch_size=10, train=False, val=True) for idx, (data, target) in enumerate(ds_val): data = Variable(torch.FloatTensor(data)).cuda() output = model_raw(data)

另外,如果想在mnist上训练MLP模型,只需运行python mnist/train.py即可。

一、安装 pytorch(> = 0.1.11)和官方网站的torchvision,例如cuda8.0 for python3.5 pip install http://download.pytorch.org/whl/cu80/torch-0.1.12.post2-cp35-cp35m-linux_x86_64.whl pip install torchvision tqdm pip install tqdm OpenCV conda install -c menpo opencv3 设置PYTHONPATH export PYTHONPATH=/path/to/pytorch-playground:$PYTHONPATH 二、ImageNet数据集

我们提供224x224x3大小的预训练imagenet验证数据集。我们首先将较短尺寸的图像调整为256,然后在中心剪裁224x224图像。然后我们将裁剪后的图像编码为jpg字符串并转储到pickle。

cd script 下载val224_compressed.pkl axel http://ml.cs.tsinghua.edu.cn/~chenxi/dataset/val224_compressed.pkl python convert.py 三、量化

我们还提供了一个简单的演示,使用几种方法将这些模型量化为指定的位宽,包括线性方法,最小最大值方法和非线性方法。

python quantize.py --type cifar10 --quant_method linear --param_bits 8 --fwd_bits 8 --bn_bits 8 --ngpu 1

四、Top1准确度

我们用线性量化方法评估流行数据集和模型的性能。BN中的运行均值和运行方差的比特宽度对于所有结果都是10比特。(32-float除外)

模型32-float12-bit10-bit8-bit6-bitMNIST98.4298.4398.4498.4498.32SVHN96.0396.0396.0496.0295.46CIFAR1093.7893.7993.8093.5890.86CIFAR10074.2774.2174.1973.7066.32STL1077.5977.6577.7077.5973.40AlexNet55.70/78.4255.66/78.4155.54/78.3954.17/77.2918.19/36.25VGG1670.44/89.4370.45/89.4370.44/89.3369.99/89.1753.33/76.32VGG1971.36/89.9471.35/89.9371.34/89.8870.88/89.6256.00/78.62ResNet1868.63/88.3168.62/88.3368.49/88.2566.80/87.2019.14/36.49ResNet3472.50/90.8672.46/90.8272.45/90.8571.47/90.0032.25/55.71ResNet5074.98/92.1774.94/92.1274.91/92.0972.54/90.442.43/5.36ResNet10176.69/93.3076.66/93.2576.22/92.9065.69/79.541.41/1.18ResNet15277.55/93.5977.51/93.6277.40/93.5474.95/92.469.29/16.75SqueezeNetV056.73/79.3956.75/79.4056.70/79.2753.93/77.0414.21/29.74SqueezeNetV156.52/79.1356.52/79.1556.24/79.0354.56/77.3317.10/32.46InceptionV376.41/92.7876.43/92.7176.44/92.7373.67/91.341.50/4.82

注意:`ImageNet 32-float`模型直接来自`torchvision` ### 五、定义参数 在`quantize.py`可以定义参数 参数默认值描述 & 参数typecifar10mnist,svhn,cifar10,cifar100,stl10,alexnet,vgg16,vgg16_bn,vgg19,vgg19_bn,resent18,resent34,resnet50,resnet101,resnet152,squeezenet_v0,squeezenet_v1,inception_v3quant_methodlinearquantization method:linear,minmax,log,tanhparam_bits8bit-width of weights and biasfwd_bits8bit-width of activationbn_bits32bit-width of running mean and running vairanceoverflow_rate0.0overflow rate threshold for linear quantization methodn_samples20number of samples to make statistics for activation

项目地址:[aaron-xichen/pytorch-playground](https://github.com/aaron-xichen/pytorch-playground) 原创文章,转载请注明 :PyTorch流行的预训练模型和数据集列表pytorch-playground - pytorch中文网 原文出处: https://ptorch.com/news/171.html 问题交流群 :168117787


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3