pytorch集锦(4)

您所在的位置:网站首页 hammerheadshark怎么读 pytorch集锦(4)

pytorch集锦(4)

2024-06-11 20:42| 来源: 网络整理| 查看: 265

目录 加载训练好的模型下载模型权重图像预处理打开要预测的图像传递图像运行模型下载数据打开imagenet_classes.txt预测结果前5个最可能分类

加载训练好的模型 pip3 install pillow >>> from torchvision import models >>> dir(models) ['AlexNet', 'DenseNet', 'Inception3', 'ResNet', 'SqueezeNet', 'VGG', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'alexnet', 'densenet', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'inception', 'inception_v3', 'resnet', 'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'squeezenet', 'squeezenet1_0', 'squeezenet1_1', 'vgg', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn']

出现 这个错误

ImportError: cannot import name 'PILLOW_VERSION' from 'PIL' (/usr/lib64/python3.11/site-packages/PIL/__init__.py)

需要将对应出错文件的PILLOW_VERSION改为__version__

PILLOW_VERSION在Pillow 7.0.0之后的版本被移除了

下载模型权重 >>> resnet=models.resnet101(pretrained=True) Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /home/spx/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth 100.0% >>> resnet ResNet( (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) (layer1): Sequential( (0): Bottleneck( .... 图像预处理 >>>from torchvision import transforms >>> prporocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.285,0.456,0.406],std=[0.229,0.224,0.225])]

1、定义了转换操作,允许快速定义基本预重函数的管道。 2、将输入图像缩放到256*256个像素,围绕中心将图像裁剪为224*224个像素。 3、图像转换为一个张量,对RGB分量进行归一化处理。

打开要预测的图像 >>from PIL import Image >>> img=Image.open('/home/spx/learn/pic/1.jpg') >>> img >>> img.show()

在这里插入图片描述

传递图像 >>> img_t=preprocess(img) >>>> import torch >>> batch_t=torch.unsqueeze(img_t,0) 运行模型 >>> resnet.eval() >>>> out=resnet(batch_t) 下载数据

https://image-net.org/里找到 imagenet_classes.txt下载,这是标签文件。 或者 https://gitee.com/lonerlin/classification/blob/master/imagenet_classes.txt https://github.com/ethereon/caffe-tensorflow/blob/master/examples/imagenet/imagenet-classes.txt

这里, 将它复制下来。 新建一个imagenet_classes.txt文件,粘贴进去

打开imagenet_classes.txt >>> with open('/home/spx/learn/pic/imagenet_classes.txt') as f: ... labels=[line.strip() for line in f.readlines()] ... >>> labels ['tench, Tinca tinca', 'goldfish, Carassius auratus', 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', 'tiger shark, Galeocerdo cuvieri', 'hammerhead, hammerhead shark', 'electric ray, crampfish, numbfish, torpedo', 'stingray', 'cock', 'hen', 'ostrich, Struthio camelus', 'brambling, Fringilla montifringilla', 'goldfinch, Carduelis carduelis', 'house finch, linnet, Carpodacus mexicanus', 'junco, snowbird', 'indigo bunting, indigo f .... 预测结果 >>> _,index=torch.max(out,1) >>> percentage=torch.nn.functional.softmax(out,dim=1)[0]*100 >>> labels[index[0]] 'hog, pig, grunter, squealer, Sus scrofa' >>> percentage[index[0]].item() 60.67759323120117

pig预测正确

前5个最可能分类 >>> _,indices=torch.sort(out,descending=True) >>> [(labels[idx],percentage[idx].item()) for idx in indices[0][:5]] [('hog, pig, grunter, squealer, Sus scrofa', 60.67759323120117), ('weasel', 9.75589656829834), ('guinea pig, Cavia cobaya', 8.112009048461914), ('black-footed ferret, ferret, Mustela nigripes', 5.257884979248047), ('polecat, fitch, foulmart, foumart, Mustela putorius', 4.569345474243164)]


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3