【PyTorch】DCGAN in Practice (Part 3): Anime Avatar Generation


2024-06-15 03:50 | Source: compiled from the web

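Contents

1. Results
2. Environment Setup
    2.1 Python
    2.2 PyTorch, CUDA
    2.3 Python IDE
3. Implementation
    3.1 Data Preprocessing (data.py): (1) Imports, (2) Data class
    3.2 Model: Generator, Discriminator, Weight Initialization (model.py): (1) Imports, (2) Generator, (3) Discriminator, (4) Weight initialization
    3.3 Network Training (net.py): (1) Imports, (2) Training class
    3.4 Main Script (main.py): (1) Imports, (2) Hyperparameters, (3) Instantiation, (4) Training
4. Training Process
    4.1 Generator and Discriminator Loss Curves
    4.2 D(x) and D(G(z)) Curves
    4.3 Final Generated Images
5. Complete Code
6. References
7. Feedback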

1. Results

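DCGAN is trained on the faces dataset to generate anime-style avatars. The model does produce anime faces, but some details still differ noticeably from real images, for example eye size and eye color. Later I will put together a summary of the outputs at different training epochs for the MNIST, Oxford17, and faces datasets. The anime avatar program still uses the same four files (data.py, model.py, net.py, main.py), but some implementation details have changed. The earlier programs for MNIST and Oxford17 are here: 【PyTorch】DCGAN in Practice (Part 1): Handwritten Digit Generation on the MNIST Dataset; 【PyTorch】DCGAN in Practice (Part 2): Flower Image Generation on Oxford17.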

2. Environment Setup

2.1 Python

Python version: 3.7

2.2 PyTorch, CUDA

I won't cover installation in detail here; there are plenty of installation guides online.
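Since installation is left to online guides, a quick sanity check before training is to print the interpreter version and whether PyTorch can see a GPU. The sketch below is a hypothetical helper (`describe_environment` is not part of the original project); it only imports `torch` if it is installed, so it runs either way.

```python
import importlib.util
import sys


def describe_environment():
    """Hypothetical helper (not in the original project): report Python/PyTorch/CUDA status."""
    info = {
        "python": "{}.{}.{}".format(*sys.version_info[:3]),
        "torch": None,
        "cuda_available": False,
    }
    # Import torch only if it is installed, so the check never crashes
    if importlib.util.find_spec("torch") is not None:
        import torch
        info["torch"] = torch.__version__
        info["cuda_available"] = torch.cuda.is_available()
    return info


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```

If `cuda_available` is `False`, the `device` line in main.py falls back to the CPU, which works but trains much more slowly.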

2.3 Python IDE

PyCharm

3. Implementation

The project is split into four files: data.py, model.py, net.py, and main.py.
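The comments in model.py track the feature-map size after each layer (4 x 4 up to 64 x 64 in the Generator, and back down to 1 x 1 in the Discriminator). Those sizes follow from PyTorch's output-size formulas for `ConvTranspose2d` and `Conv2d`, assuming dilation 1 and no output padding (the defaults used in model.py). The pure-Python sketch below uses illustrative function names and replays the (kernel, stride, padding) settings from model.py to check the shapes without building the network.

```python
def conv_transpose_out(size, kernel, stride, padding):
    # ConvTranspose2d output size with dilation=1, output_padding=0
    return (size - 1) * stride - 2 * padding + kernel


def conv_out(size, kernel, stride, padding):
    # Conv2d output size with dilation=1
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) for each layer, in the order used in model.py
GENERATOR_LAYERS = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]
DISCRIMINATOR_LAYERS = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]


def trace(size, layers, step):
    # Apply the size formula layer by layer and record every intermediate size
    sizes = [size]
    for kernel, stride, padding in layers:
        size = step(size, kernel, stride, padding)
        sizes.append(size)
    return sizes


generator_sizes = trace(1, GENERATOR_LAYERS, conv_transpose_out)
discriminator_sizes = trace(64, DISCRIMINATOR_LAYERS, conv_out)
print("Generator:    ", generator_sizes)      # [1, 4, 8, 16, 32, 64]
print("Discriminator:", discriminator_sizes)  # [64, 32, 16, 8, 4, 1]
```

This also shows why `ReadData` resizes images to 64 x 64: a 128 x 128 input would leave the Discriminator's last layer with a 5 x 5 map instead of a single 1 x 1 output, so `view(-1)` would no longer line up with one label per image.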

3.1 Data Preprocessing (data.py)

(1) Imports

```python
from torch.utils.data import DataLoader
from torchvision import utils, datasets, transforms
```

(2) Data class

```python
class ReadData():
    def __init__(self, data_path, image_size=64):
        self.root = data_path
        self.image_size = image_size
        self.dataset = self.getdataset()

    def getdataset(self):
        # ImageFolder expects the layout root/<subfolder>/<image files>
        dataset = datasets.ImageFolder(
            root=self.root,
            transform=transforms.Compose([
                transforms.Resize(self.image_size),
                transforms.CenterCrop(self.image_size),
                transforms.ToTensor(),
                # map [0, 1] to [-1, 1], the same range as the Generator's Tanh output
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]))
        print(f'Total Size of Dataset: {len(dataset)}')
        return dataset

    def getdataloader(self, batch_size=128):
        dataloader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0)
        return dataloader
```

3.2 Model: Generator, Discriminator, Weight Initialization (model.py)

(1) Imports

```python
import torch.nn as nn
```

(2) Generator

```python
class Generator(nn.Module):
    def __init__(self, nz, ngf, nc):
        super(Generator, self).__init__()
        self.nz = nz
        self.ngf = ngf
        self.nc = nc
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(self.nz, self.ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(self.ngf * 8, self.ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(self.ngf * 2, self.ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(self.ngf, self.nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)
```

(3) Discriminator

```python
class Discriminator(nn.Module):
    def __init__(self, ndf, nc):
        super(Discriminator, self).__init__()
        self.ndf = ndf
        self.nc = nc
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(self.nc, self.ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(self.ndf, self.ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(self.ndf * 4, self.ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(self.ndf * 8, 1, 4, 1, 0, bias=False),
            # state size. (1) x 1 x 1
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
```

(4) Weight initialization

```python
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # matches both Conv2d and ConvTranspose2d
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
```

3.3 Network Training (net.py)

(1) Imports

```python
import torch
import torch.nn as nn
from torchvision import utils, datasets, transforms
import time
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import os
```

(2) Training class

```python
class DCGAN():
    def __init__(self, lr, beta1, nz, batch_size, num_showimage, device,
                 model_save_path, figure_save_path, generator, discriminator, data_loader):
        self.real_label = 1
        self.fake_label = 0
        self.nz = nz
        self.batch_size = batch_size
        self.num_showimage = num_showimage
        self.device = device
        self.model_save_path = model_save_path
        self.figure_save_path = figure_save_path
        self.G = generator.to(device)
        self.D = discriminator.to(device)
        self.opt_G = torch.optim.Adam(self.G.parameters(), lr=lr, betas=(beta1, 0.999))
        self.opt_D = torch.optim.Adam(self.D.parameters(), lr=lr, betas=(beta1, 0.999))
        self.criterion = nn.BCELoss().to(device)
        self.dataloader = data_loader
        # fixed latent vectors, reused after every epoch to visualize progress
        self.fixed_noise = torch.randn(self.num_showimage, nz, 1, 1, device=device)
        self.img_list = []
        self.G_loss_list = []
        self.D_loss_list = []
        self.D_x_list = []
        self.D_z_list = []

    def train(self, num_epochs):
        loss_tep = 10
        G_loss = 0
        D_loss = 0
        print("Starting Training Loop...")
        # For each epoch
        for epoch in range(num_epochs):
            # ********** timing **********
            beg_time = time.time()
            # For each batch in the dataloader
            for i, data in enumerate(self.dataloader):
                ############################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###########################
                x = data[0].to(self.device)
                b_size = x.size(0)
                lbx = torch.full((b_size,), self.real_label, dtype=torch.float, device=self.device)
                D_x = self.D(x).view(-1)
                LossD_x = self.criterion(D_x, lbx)
                D_x_item = D_x.mean().item()

                z = torch.randn(b_size, self.nz, 1, 1, device=self.device)
                gz = self.G(z)
                lbz1 = torch.full((b_size,), self.fake_label, dtype=torch.float, device=self.device)
                # detach() keeps the D update from backpropagating into G
                D_gz1 = self.D(gz.detach()).view(-1)
                LossD_gz1 = self.criterion(D_gz1, lbz1)
                D_gz1_item = D_gz1.mean().item()

                LossD = LossD_x + LossD_gz1
                self.opt_D.zero_grad()
                LossD.backward()
                self.opt_D.step()
                # accumulate a Python float, not the tensor, so no autograd history is kept
                D_loss += LossD.item()

                ############################
                # (2) Update G network: maximize log(D(G(z)))
                ###########################
                # fake labels are real for generator cost
                lbz2 = torch.full((b_size,), self.real_label, dtype=torch.float, device=self.device)
                D_gz2 = self.D(gz).view(-1)
                D_gz2_item = D_gz2.mean().item()
                LossG = self.criterion(D_gz2, lbz2)
                self.opt_G.zero_grad()
                LossG.backward()
                self.opt_G.step()
                G_loss += LossG.item()

                end_time = time.time()
                # ********** timing **********
                run_time = round(end_time - beg_time)
                print(
                    f'Epoch: [{epoch + 1:0>{len(str(num_epochs))}}/{num_epochs}]',
                    f'Step: [{i + 1:0>{len(str(len(self.dataloader)))}}/{len(self.dataloader)}]',
                    f'Loss-D: {LossD.item():.4f}',
                    f'Loss-G: {LossG.item():.4f}',
                    f'D(x): {D_x_item:.4f}',
                    f'D(G(z)): [{D_gz1_item:.4f}/{D_gz2_item:.4f}]',
                    f'Time: {run_time}s',
                    end='\r\n'
                )
                # Save Losses for plotting later
                self.G_loss_list.append(LossG.item())
                self.D_loss_list.append(LossD.item())
                # Save D(X) and D(G(z)) for plotting later
                self.D_x_list.append(D_x_item)
                self.D_z_list.append(D_gz2_item)

            # # Save the Best Model
            # if LossG < loss_tep:
            #     torch.save(self.G.state_dict(), 'model.pt')
            #     loss_tep = LossG
            if not os.path.exists(self.model_save_path):
                os.makedirs(self.model_save_path)
            torch.save(self.D.state_dict(), self.model_save_path + 'disc_{}.pth'.format(epoch))
            torch.save(self.G.state_dict(), self.model_save_path + 'gen_{}.pth'.format(epoch))
            # Check how the generator is doing by saving G's output on fixed_noise
            with torch.no_grad():
                fake = self.G(self.fixed_noise).detach().cpu()
            self.img_list.append(utils.make_grid(fake * 0.5 + 0.5, nrow=10))
            print()

        if not os.path.exists(self.figure_save_path):
            os.makedirs(self.figure_save_path)
        # Loss curves (every 10th iteration is plotted)
        plt.figure(1, figsize=(8, 4))
        plt.title("Generator and Discriminator Loss During Training")
        plt.plot(self.G_loss_list[::10], label="G")
        plt.plot(self.D_loss_list[::10], label="D")
        plt.xlabel("iterations")
        plt.ylabel("Loss")
        plt.axhline(y=0, label="0", c="g")  # asymptote
        plt.legend()
        plt.savefig(self.figure_save_path + str(num_epochs) + 'epochs_' + 'loss.jpg', bbox_inches='tight')

        # D(x) and D(G(z)) curves (every 10th iteration is plotted)
        plt.figure(2, figsize=(8, 4))
        plt.title("D(x) and D(G(z)) During Training")
        plt.plot(self.D_x_list[::10], label="D(x)")
        plt.plot(self.D_z_list[::10], label="D(G(z))")
        plt.xlabel("iterations")
        plt.ylabel("Probability")
        plt.axhline(y=0.5, label="0.5", c="g")  # asymptote
        plt.legend()
        plt.savefig(self.figure_save_path + str(num_epochs) + 'epochs_' + 'D(x)D(G(z)).jpg', bbox_inches='tight')

        # Animation of the fixed-noise samples, one frame per epoch
        fig = plt.figure(3, figsize=(5, 5))
        plt.axis("off")
        ims = [[plt.imshow(item.permute(1, 2, 0), animated=True)] for item in self.img_list]
        ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
        HTML(ani.to_jshtml())  # no visible effect here; HTML objects only render as a notebook cell's output
        # ani.to_html5_video()
        ani.save(self.figure_save_path + str(num_epochs) + 'epochs_' + 'generation.gif', writer='pillow')

        plt.figure(4, figsize=(8, 4))
        # Plot the real images
        plt.subplot(1, 2, 1)
        plt.axis("off")
        plt.title("Real Images")
        real = next(iter(self.dataloader))  # real[0]: images, real[1]: labels
        plt.imshow(utils.make_grid(real[0][:self.num_showimage] * 0.5 + 0.5, nrow=10).permute(1, 2, 0))

        # Load the Best Generative Model
        # self.G.load_state_dict(
        #     torch.load(self.model_save_path + 'gen_{}.pth'.format(epoch), map_location=torch.device(self.device)))
        self.G.eval()
        # Generate the Fake Images
        with torch.no_grad():
            fake = self.G(self.fixed_noise).cpu()
        # Plot the fake images
        plt.subplot(1, 2, 2)
        plt.axis("off")
        plt.title("Fake Images")
        fake = utils.make_grid(fake[:self.num_showimage] * 0.5 + 0.5, nrow=10).permute(1, 2, 0)
        plt.imshow(fake)
        # Save the comparison result
        plt.savefig(self.figure_save_path + str(num_epochs) + 'epochs_' + 'result.jpg', bbox_inches='tight')
        plt.show()

    def test(self, epoch):
        os.makedirs(self.figure_save_path, exist_ok=True)
        # Size of the Figure
        plt.figure(figsize=(8, 4))
        # Plot the real images
        plt.subplot(1, 2, 1)
        plt.axis("off")
        plt.title("Real Images")
        real = next(iter(self.dataloader))  # real[0]: images, real[1]: labels
        plt.imshow(utils.make_grid(real[0][:self.num_showimage] * 0.5 + 0.5, nrow=10).permute(1, 2, 0))
        # Load the generator checkpoint of the given epoch (gen_*.pth, not disc_*.pth)
        self.G.load_state_dict(torch.load(self.model_save_path + 'gen_{}.pth'.format(epoch),
                                          map_location=torch.device(self.device)))
        self.G.eval()
        # Generate the Fake Images and move them to the CPU for plotting
        with torch.no_grad():
            fake = self.G(self.fixed_noise.to(self.device)).cpu()
        # Plot the fake images
        plt.subplot(1, 2, 2)
        plt.axis("off")
        plt.title("Fake Images")
        fake = utils.make_grid(fake * 0.5 + 0.5, nrow=10)
        plt.imshow(fake.permute(1, 2, 0))
        # Save the comparison result
        plt.savefig(self.figure_save_path + 'result.jpg', bbox_inches='tight')
        plt.show()
```

3.4 Main Script (main.py)

(1) Imports

```python
from data import ReadData
from model import Discriminator, Generator, weights_init
from net import DCGAN
import torch
```

(2) Hyperparameters

```python
ngpu = 1             # number of GPUs to use (0 = CPU only)
ngf = 64             # base feature-map width of the Generator
ndf = 64             # base feature-map width of the Discriminator
nc = 3               # image channels (RGB)
nz = 100             # length of the latent vector z
lr = 0.003           # learning rate (the DCGAN paper uses 0.0002)
beta1 = 0.5          # Adam beta1
batch_size = 100
num_showimage = 100  # images per grid (10 x 10)
data_path = "./oxford17_class"  # ImageFolder root (root/<subfolder>/<images>); point it at the dataset you train on
model_save_path = "./models/"
figure_save_path = "./figures/"
device = torch.device('cuda:0' if (torch.cuda.is_available() and ngpu > 0) else 'cpu')
```

(3) Instantiation

```python
dataset = ReadData(data_path)
dataloader = dataset.getdataloader(batch_size=batch_size)

G = Generator(nz, ngf, nc).apply(weights_init)
print(G)
D = Discriminator(ndf, nc).apply(weights_init)
print(D)

dcgan = DCGAN(lr, beta1, nz, batch_size, num_showimage, device,
              model_save_path, figure_save_path, G, D, dataloader)
```

(4) Training

```python
dcgan.train(num_epochs=20)
```

4. Training Process

4.1 Generator and Discriminator Loss Curves

Loss curves of the Generator and Discriminator during training (200 epochs shown as an example):

[Figure: Generator and Discriminator loss curves]
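When reading these curves it helps to know what the two losses measure. In net.py both use `nn.BCELoss` (averaged over the batch): Loss-D is BCE(D(x), 1) + BCE(D(G(z)), 0) = −log D(x) − log(1 − D(G(z))), and Loss-G is BCE(D(G(z)), 1) = −log D(G(z)). The sketch below evaluates these for single scalar outputs; the helper names are illustrative, and it mirrors `BCELoss`'s clamping of the log at −100. If the Discriminator outputs 0.5 for everything (the theoretical equilibrium), Loss-D is 2 ln 2 ≈ 1.386 and Loss-G is ln 2 ≈ 0.693, which is a useful reference level rather than a value the curves are guaranteed to reach.

```python
import math


def bce(p, target):
    # Binary cross-entropy for one probability; logs clamped at -100 like nn.BCELoss
    log_p = max(math.log(p), -100.0) if p > 0 else -100.0
    log_1mp = max(math.log(1 - p), -100.0) if p < 1 else -100.0
    return -(target * log_p + (1 - target) * log_1mp)


def discriminator_loss(d_x, d_gz):
    # Loss-D in net.py: real images labelled 1, generated images labelled 0
    return bce(d_x, 1.0) + bce(d_gz, 0.0)


def generator_loss(d_gz):
    # Loss-G in net.py: generated images labelled 1 (the non-saturating generator loss)
    return bce(d_gz, 1.0)


print(f"Loss-D at D = 0.5: {discriminator_loss(0.5, 0.5):.4f}")       # 1.3863
print(f"Loss-G at D = 0.5: {generator_loss(0.5):.4f}")                # 0.6931
print(f"Loss-D when D wins: {discriminator_loss(0.99, 0.01):.4f}")    # 0.0201
print(f"Loss-G when D wins: {generator_loss(0.01):.4f}")              # 4.6052
```

Read this way, a Loss-D close to 0 together with a large Loss-G means the Discriminator is easily rejecting the generated images.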

4.2 D(x) and D(G(z)) Curves

Discriminator outputs during training (200 epochs shown as an example):

[Figure: D(x) and D(G(z)) curves]
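The green reference line at 0.5 in this plot (`plt.axhline(y=0.5, ...)` in net.py) comes from the analysis in the original GAN paper: for a fixed generator with sample distribution p_g, the discriminator that maximizes the value function V is the ratio below, so once p_g matches the data distribution the best possible discriminator outputs 1/2 everywhere, and Loss-D equals log 4 = 2 ln 2. This is a theoretical optimum; in practice D(x) and D(G(z)) approach 0.5 only approximately, if at all.

```latex
V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right]
        + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]

D^{*}_{G}(x) = \arg\max_{D} V(D, G)
             = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}

p_g = p_{\text{data}} \;\Longrightarrow\; D^{*}_{G}(x) = \tfrac{1}{2},
\qquad V(D^{*}_{G}, G) = 2\log\tfrac{1}{2} = -\log 4
```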

4.3 Final Generated Images

Images generated after training (5 epochs shown as an example):

[Figure: real images vs. final generated images]
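The real and fake grids in this comparison are built with `make_grid(... * 0.5 + 0.5)`. That expression undoes the `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` step in data.py, which maps pixel values from [0, 1] to [-1, 1], the same range as the Generator's final `Tanh`. The pure-Python sketch below uses illustrative helper names and plain lists instead of tensors to show the per-channel round trip.

```python
def normalize(value, mean=0.5, std=0.5):
    # Same per-channel formula as torchvision's Normalize: (x - mean) / std
    return (value - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    # Inverse applied before plotting in net.py: x * 0.5 + 0.5
    return value * std + mean


pixels = [0.0, 0.25, 0.5, 1.0]                # ToTensor output lies in [0, 1]
scaled = [normalize(p) for p in pixels]       # [-1.0, -0.5, 0.0, 1.0]
restored = [denormalize(s) for s in scaled]   # [0.0, 0.25, 0.5, 1.0]
print(scaled, restored)

# Tanh outputs lie in [-1, 1], so denormalized generator pixels lie in [0, 1]
print(denormalize(-1.0), denormalize(1.0))    # 0.0 1.0
```

Because `Tanh` bounds the Generator's output, the denormalized fake images already fall in the [0, 1] range that `plt.imshow` expects for floating-point RGB data.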

5. Complete Code

Link: https://pan.baidu.com/s/15J6sZL3rCPLm2jZFEuyzNw (extraction code: DGAN)

6. References

https://blog.csdn.net/qq_42951560/article/details/112199229
https://blog.csdn.net/qq_42951560/article/details/110308336

7. Feedback

If you run into problems running the code, feel free to send me a private message!


