pytorch 使用预训练模型如resnet、vgg等并修改部分结构 |
您所在的位置:网站首页 › vgg网络模型参数 › pytorch 使用预训练模型如resnet、vgg等并修改部分结构 |
pytorch 使用预训练模型并修改部分结构
在一些常见的如检测、分类等计算机视觉任务中,基于深度学习的方法取得了很好的结果,其中一些经典模型也往往成为相关任务及比赛的baseline。在pytorch的视觉库torchvision中,提供了models模块供我们直接调用这些经典网络,如VGG,Resnet等。使用中往往不能直接使用现成的模型,需要进行一些修改。实际上我们可以很方便的在pytorch中使用并修改模型。 1. 直接使用pytorch中的经典模型直接通过models调用即可,如 from torchvision import models res101 = models.resnet101(pretrained=True) vgg19 = models.vgg19_bn(pretrained=True)如下如所示,models模块的__init__.py 包含了一系列不同的网络结构 以及网络模型的不同层数的结构,如resnet50, resnet101, vgg16, vgg19等 我们只需查阅手册或源码寻找是否有这个网络模型,有的话直接拿来用即可。参数 pre_trained为True时表示模型参数是在ImageNet预训练过的,否则就是随机初始化的参数。 在首次使用时,pytorch会自动下载模型文件,保存在用户cache目录内 在参加一些图像检测、分类、分割比赛时,或者一些不需要大幅修改网络结构的场景,可以直接采用pytorch自带的网络结构,无需自行搭建。 2. 模型结构的修改在分类问题上,模型的最后一层一般是一个全连接层,输出的神经元个数就是类别信息,最后输出结果是一个浮点向量,大小表示某一类别的可能性,数值越大说明越倾向于分为该类。 显然直接使用预训练的网络不加修改那么总类别数就是固定的,当我们使用的场景类别数不一致时,就要自行修改模型的最后一层。那么如何进行替换和修改呢? 我们知道,在自定义网络结构时,通常是: class myModel(Module): def __init__(self): # 模型结构 self.conv1 = xxxx self.fc1 = xxxxx self.m = nn.Sequential(a,b,c...) def forward(self,x): # 前向传播这样的形式。换言之,模型的每一层都记录在了这个模型类的实例的成员变量里。因此只要我们知道要修改的那一层叫什么名字,就能够进行修改。 例如 对resnet,最后一层全连接层就叫fc,所以我们可以: res101 = models.resnet101(pretrained=True) numFit = res101.fc.in_features res101.fc = nn.Linear(numFit, numClass)res101.fc就是这个网络的最后一层全连接层,in_feature是输出神经元数量,我们将它修改为输入神经元不变(也不能变,不然就出错了)输出神经元为我们需要的类别数的全连接网络。 有时候分类任务不光要输出类别,也要输出置信度,通常置信度就是分类为这个类别的概率,既然是概率,就要满足 0 ≤ P i ≤ 1 , Σ i = 1 N P i = 1 0\le P_{i}\le 1, \Sigma_{i=1}^{N}P_{i} = 1 0≤Pi≤1,Σi=1NPi=1 由于全连接网络直接输出的结果往往不能称之为“置信度”(只有大小之分,不满足0-1之间,和也不是1),通常会在后面加一层softmax作为激活函数,这样输出结果就是一个概率值了: res101.fc = nn.Sequential(nn.Linear(numFit, numClass), nn.Softmax(dim=1))以上是最简单的模型某一层就是一个单独的成员变量的情况。 那如果模型把好多东西塞进了一个Sequential怎么办呢? 例如vgg: 我们在torchvision/moduls/vgg.py 中找到VGG类的定义: 显然分类相关的3层全连接、激活函数、dropout都在一个Sequential类、名字叫做classifier的成员变量里,这种情况,我们需要把整个classifier都复写吗? 答案是不需要的。 我们知道Sequential同样继承自nn.Module类,这个类有一个成员变量叫做_modules 这是一个有序字典,存放了模块名称 - 模块内容 的键值对。 每次新添加一层,都会做一次 self._modules[name] = module这个操作。 这个name这里很有意思,一般我们很少给每一层网络都起一个名字,那默认的名字实际上是该模块索引的字符串形式。比如上述的vgg的classifier,它的第一层全连接,名字叫做’1’,最后一层name是’6’。这个名字部分以后有时间专门讨论一下。 回到这个问题,Sequential继承自nn.Module,自然也有这个字典。 所以对于vgg,我们可以: vgg19 = models.vgg19_bn(pretrained=True) vgg19.classifier._modules['6'] = nn.Sequential(nn.Linear(4096, numClass), nn.Softmax(dim=1))就可以将最后一层全连接层替换掉了。 中间其他层也可以用类似的方式替换。 小结总结来说,pytorch提供的网络模型还是比较实用的,对于不需要大幅修改的网络结构只要直接调用再局部修改就可以,满足一了一些简单的深度学习需求场景,可以不需要自己重新写一遍网路结构了。 在修改方面,基于pytorch模型的定义方式,我们只要知道其模型结构,这一点可以直接查找pytorch这部分的源码,了解到成员变量的名字,如果是Sequential,可以再通过_modules这个字典查找,都将能够较容易的找到被修改的那一层,直接替换成我们需要的结构即可。当然,替换后与训练的权重就不见了,取而代之的是随机初始化权重。 |
今日新闻 |
推荐新闻 |
CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3 |