使用vgg16训练自己的数据集进行图像分类

您所在的位置:网站首页 vgg16层数 使用vgg16训练自己的数据集进行图像分类

使用vgg16训练自己的数据集进行图像分类

2023-06-17 22:47| 来源: 网络整理| 查看: 265

使用vgg16训练自己的数据集 数据集的准备

先创建一个data文件夹,在data文件夹中再建一个photos文件夹用于存放图片。因为我们训练的是分类网络,所以在photos文件夹中再按类别建几个子文件夹,子文件夹的名字就是我们的类别名字(如图:我的是一个三分类),每一个子文件夹里放的就是对应类别的图片。 photos文件夹

# 数据集的处理 -- dataset preparation script.
# Copies every image under data/photos/<class>/ into data/train/<class>/ or
# data/val/<class>/ according to a random per-class split.
import os
from shutil import copy
import random


def mkfile(file):
    """Create directory *file* (and any missing parents) if it is absent."""
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(file, exist_ok=True)


# Every entry of data/photos whose name does not contain ".txt" is treated as
# a class folder; the folder name is the class name.
file_path = 'data/photos'
flower_class = [cla for cla in os.listdir(file_path) if ".txt" not in cla]

# Create data/train/<class> for every class.
mkfile('data/train')
for cla in flower_class:
    mkfile('data/train/' + cla)

# Create data/val/<class> for every class.
mkfile('data/val')
for cla in flower_class:
    mkfile('data/val/' + cla)

# Fraction of each class that goes to the validation set
# (0.1 -> train : val = 9 : 1).
split_rate = 0.1

for cla in flower_class:
    cla_path = file_path + '/' + cla + '/'  # sub-directory of one class
    images = os.listdir(cla_path)  # names of all entries in that directory
    num = len(images)
    # Names randomly chosen for validation. Stored as a set so the membership
    # test below is O(1) instead of scanning a list for every image.
    # NOTE: int() truncates, so a class with fewer than 1/split_rate images
    # contributes no validation images.
    eval_index = set(random.sample(images, k=int(num * split_rate)))
    for index, image in enumerate(images):
        image_path = cla_path + image
        if image in eval_index:
            new_path = 'data/val/' + cla
        else:
            # everything not sampled for validation is a training image
            new_path = 'data/train/' + cla
        copy(image_path, new_path)
        # single-line progress indicator, rewritten in place via "\r"
        print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")
    print()

print("processing done!")


# vgg16的模型文件 -- model definition.
# NOTE(review): presumably saved as model.py, since the training script below
# does `from model import vgg` -- confirm.
import torch.nn as nn
import torch


class VGG(nn.Module):
    """VGG-style classifier: a convolutional feature stack plus an MLP head.

    Args:
        features: module producing the convolutional features; its flattened
            output must have 512*7*7 elements per sample to match the first
            linear layer of the classifier.
        num_classes: size of the final linear layer's output.
        init_weights: when true, run ``_initialize_weights`` on construction.
    """

    def __init__(self, features, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(512 * 7 * 7, 2048),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Run features, flatten everything but the batch dim, then classify."""
        # N x 3 x 224 x 224 (input size implied by the 512*7*7 classifier input)
        x = self.features(x)
        # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1)
        # N x 512*7*7
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Xavier-uniform init for conv/linear weights; biases set to zero."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)


def make_features(cfg: list):
    """Build the convolutional stack described by *cfg*.

    Each integer ``v`` in *cfg* adds a 3x3, padding-1 convolution with ``v``
    output channels followed by an in-place ReLU; the string ``"M"`` adds a
    2x2, stride-2 max-pool. The first convolution takes 3 input channels.
    """
    layers = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v
    return nn.Sequential(*layers)


# Layer layouts consumed by make_features, keyed by model name.
cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg(model_name="vgg16", **kwargs):
    """Return a ``VGG`` built from ``cfgs[model_name]``.

    Extra keyword arguments are forwarded to ``VGG``. If *model_name* is not a
    key of ``cfgs`` a warning is printed and the process exits with status -1.
    """
    try:
        cfg = cfgs[model_name]
    except KeyError:
        # Only the failed lookup is handled; a bare `except:` would also hide
        # unrelated errors. SystemExit is raised directly because the `exit`
        # builtin is only installed by the `site` module.
        print("Warning: model number {} not in cfgs dict!".format(model_name))
        raise SystemExit(-1)
    model = VGG(make_features(cfg), **kwargs)
    return model


# 训练代码 -- training script.
import os
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm

from model import vgg


def main():
    """Train VGG16 on the train/val image folders and keep the best weights.

    Writes the index -> class-name mapping to ``class_indices.json`` and saves
    the state dict with the highest validation accuracy to ``./vgg16Net.pth``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    # The data directory is resolved relative to two levels above the current
    # working directory.
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    image_path = os.path.join(data_root, "code/fenlei", "data")
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # class_to_idx maps class-folder name -> integer label; invert it so the
    # inference script can turn a predicted index back into a name.
    flower_list = train_dataset.class_to_idx
    cla_dict = {val: key for key, val in flower_list.items()}
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 4
    # os.cpu_count() may return None; fall back to 1 so min() cannot raise.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    model_name = "vgg16"
    net = vgg(model_name=model_name, num_classes=3, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # ---- train ----
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader)
        for images, labels in train_bar:
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # Read the scalar once and reuse it for the running total and
            # the progress-bar text.
            loss_value = loss.item()
            running_loss += loss_value
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss_value)

        # ---- validate ----
        net.eval()
        acc = 0.0  # number of correct predictions accumulated over the epoch
        with torch.no_grad():
            for val_images, val_labels in validate_loader:
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]  # index of the max logit
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        # Only overwrite the checkpoint when validation accuracy improves.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

训练完之后就会得到一个自己的训练模型,可以用于测试。

# 测试代码 -- inference script: classify a single image with the trained weights.
import torch
from model import vgg
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json

# Preprocessing: resize to 224x224, convert to tensor, normalize each channel
# with mean 0.5 and std 0.5.
data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Load the image. A raw string keeps the Windows backslashes literal; the
# non-raw form relied on invalid escape sequences, which newer Python versions
# warn about. The resulting path value is unchanged.
image_path = r"F:\微信\WeChat Files\zhzwq66666\FileStorage\File/2021-02\盖帽/110.png"
img = Image.open(image_path).convert('RGB')
plt.imshow(img)
# [C, H, W] after the transform
img = data_transform(img)
# add the batch dimension -> [N, C, H, W]
img = torch.unsqueeze(img, dim=0)

# Read the index -> class-name mapping. The keys are strings because the file
# is JSON.
try:
    with open('./class_indices.json', 'r') as json_file:
        class_indict = json.load(json_file)
except (OSError, ValueError) as e:
    # OSError: file missing/unreadable; ValueError covers json.JSONDecodeError.
    print(e)
    raise SystemExit(-1)

# Create the model.
model = vgg(num_classes=3)
# Load the trained weights. map_location="cpu" lets a checkpoint saved from a
# CUDA device load on a machine without a GPU; the model is used on CPU here.
model_weight_path = "./vgg16Net.pth"
model.load_state_dict(torch.load(model_weight_path, map_location="cpu"))

# eval() disables Dropout for inference.
model.eval()
with torch.no_grad():
    # Squeeze away the batch dimension, then turn logits into probabilities.
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)
    # .item() yields a plain int, usable both as an index and (via str) as
    # the JSON key.
    predict_cla = torch.argmax(predict).item()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3