GigaGAN效果堪比diffusion model扩散模型?重新认识学习一下生成对抗GAN网络

您所在的位置:网站首页 gan算法模型 GigaGAN效果堪比diffusion model扩散模型?重新认识学习一下生成对抗GAN网络

GigaGAN效果堪比diffusion model扩散模型?重新认识学习一下生成对抗GAN网络

#GigaGAN效果堪比diffusion model扩散模型?重新认识学习一下生成对抗GAN网络| 来源: 网络整理| 查看: 265

GigaGan

CVPR2023收录了一篇论文GigaGAN《Scaling up GANs for Text-to-Image Synthesis》。

在论文里面展示了 1B 参数 GigaGAN,实现了比稳定扩散 v1.5、DALL·E 2 和 Parti-750M 更低的 FID。

GigaGAN可以应用于真实图像,在 3.66 秒内合成 4k 分辨率的超高分辨率图像。同这一功能类似的有SRGAN。

GiGAN的图像超分辨率修复效果如下图:

GiGAN的图像超分辨率修复

此外也以从文本到图像模型的低分辨率输出生成 4K 图像,论文中也提供了文本生成的examples。

从展示的样例中,基本上和前段时间大热的diffusion model生成图像的能力不相上下。

GigaGAN改进了StyleGAN架构,GigaGAN 生成器由文本编码分支、样式映射网络、多尺度合成网络组成,并通过稳定的注意力和自适应内核选择进行增强。在文本编码分支中,首先使用预训练的 CLIP 模型和学习的注意力层 T 提取文本嵌入。嵌入被传递到样式映射网络 M 以生成样式向量 w,类似于 StyleGAN。现在,合成网络使用样式代码作为调制,使用文本嵌入作为注意力来生成图像金字塔。引入了样本自适应核选择,以根据输入文本条件自适应地选择卷积核。

鉴别器由两个分支组成,用于处理图像和文本调节。文本分支处理类似于生成器的文本。图像分支接收图像金字塔并对每个图像尺度进行独立预测。此外,预测是在下采样层的所有后续尺度上进行的。

可惜的是这一论文还没有公开代码,暂时我们还不能自己测试一下相关功能。但是看到GAN的网络在CVPR重新焕发生机,在代码公开之前我们抓紧时间再来学习实践一下GAN。

GAN网络的原理

关于GAN网络的数学原理,知乎已经有很多篇优秀的文章进行了讲解。可以搜索进行学习,这里不再赘述。

GAN网络的实践

GAN网络有SRGAN、DCGAN等经典算法可以进行学习,Github上也有很多样例进行学习,这里我们挑选star比较高的TensorLayer社区下的TensorLayerX来讲讲GAN代码的入门,他们的社区有简单的mnist_gan、以及经典的srgan、dcgan算法,并且提供训练脚本和已经训练好的参数权重进行体验。

mnist—gan

mnist gan是一个用来学习GAN算法最简单的入门样例。训练生成器用随机的输入数据来输出生成手写数字。生成器和辨别器都很简单,都仅仅只需要三层线性网络,用来了解GAN网络的训练方式以及loss设置最好不过。

生成器代码如下:

# We define generator network. class generator(Module): def __init__(self): super(generator, self).__init__() # Linear layer with 256 units, using ReLU for output. self.g_fc1 = Linear(out_features=256, in_features=100, act=tlx.nn.ReLU) self.g_fc2 = Linear(out_features=256, in_features=256, act=tlx.nn.ReLU) self.g_fc3 = Linear(out_features=784, in_features=256, act=tlx.nn.Tanh) def forward(self, x): out = self.g_fc1(x) out = self.g_fc2(out) out = self.g_fc3(out) return

辨别器代码如下:

# We define discriminator network. class discriminator(Module): def __init__(self): super(discriminator, self).__init__() # Linear layer with 256 units, using ReLU for output. self.d_fc1 = Linear(out_features=256, in_features=784, act=tlx.LeakyReLU) self.d_fc2 = Linear(out_features=256, in_features=256, act=tlx.LeakyReLU) self.d_fc3 = Linear(out_features=1, in_features=256, act=tlx.Sigmoid) def forward(self, x): out = self.d_fc1(x) out = self.d_fc2(out) out = self.d_fc3(out) return out

整个训练脚本非常简单,包含数据加载,训练循环,生成图像展示整体不过161行代码,很适合新手用来入门。而且mnist gan的代码基于TensorLayerx编写,这是一个可以支持代码运行在tensorflow、pytorch、paddle、mindspore的框架,因此也不用担心你要重新熟悉别的框架哦!你的电脑上有哪个框架,就指定Tensorlayerx运行在哪个后端框架就行,比如你希望脚本运行在pytorch上,只需要一行代码就可以,其他代码不需要改动。

import os # os.environ['TL_BACKEND'] = 'paddle' # os.environ['TL_BACKEND'] = 'tensorflow' # os.environ['TL_BACKEND'] = 'mindspore' os.environ['TL_BACKEND'] = 'torch'

强烈建议大家动手去运行一下代码,自己运行了才有收获!

DCGAN

如果觉得mnist gan比较简单,可以继续学习这个社区下的

网络算法巩固一下GAN网络的实践。

DCGAN的网络将深度卷积神经网络CNN与生成对抗网络GAN结合用于无监督学习领域, 实现了更为强大的生成模型,并且在celebA数据集上进行了训练。

相比于mnist gan 的toy 模型结构,dcgan的生成器和辨别器要更加复杂一点,使用指定步长的卷积层代替池化层,并且加入了batch norm层,移除了全连接层,生成器除去输出层采用Tanh外,全部使用ReLU作为激活函数,判别器所有层都使用LeakyReLU作为激活函数。

和mnist gan的训练脚本一样,dcgan的程序脚本也基于TensorLayerx编写,可以支持代码运行在tensorflow、pytorch、paddle、mindspore的框架上, 整体代码仅仅200多行,代码清晰文件结构直白。尽管DCGAN稍微复杂了一点,但我相信时至今日大家应该能够很快的掌握这个模型,快去运行起来吧。

SRGAN

mnist_gan、dcgan学习完,相信你一定对GAN网络有了基本的认识和实践。接下来再来看看srgan吧,

作为github上srgan网络里面star最多的实现仓库,还是有学习价值。

SRGAN(Super-Resolution Generative Adversarial Network)即超分辨率GAN,是Christian Ledig等人于16年9月提出的一种对抗神经网络。利用卷积神经网络实现单影像的超分辨率,其瓶颈仍在于如何恢复图像的细微纹理信息。对于GAN而言,将一组随机噪声输入到生成器中,生成的图像质量往往较差。因此,作者提出了SRGAN,并定义一个loss函数以驱动模型,SRGAN最终可以生成一幅原始影像扩大4倍的高分辨率影像。

虽然对比一下不如GigaGan的效果,srgan网络在刚推出的时候效果也相当惊艳。由于srgan的代码较为复杂,希望大家可以移步github仓库自行查看。srgan的程序脚本也基于TensorLayerx编写,可以支持代码运行在tensorflow、pytorch、paddle、mindspore的框架上。而且该仓库也提供了训练好的权重帮助大家体验推理性能哦!



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3