用 PyTorch 实现一个基本 GAN 网络学习正态分布

您所在的位置:网站首页 pytorch搭建GAN 用 PyTorch 实现一个基本 GAN 网络学习正态分布

用 PyTorch 实现一个基本 GAN 网络学习正态分布

2024-05-28 17:11| 来源: 网络整理| 查看: 265

这篇文章将用 PyTorch 实现一个基本的生成对抗网络(Generative Adversarial Network, GAN),来学习一个正态分布。代码提供 Jupiter Notebook,地址见文末。

首先我们导入一些必备的库:

1234import torchimport torch.nn as nnimport torch.optim as optimfrom torch.distributions.normal import Normal

文章目录

1 定义我们要学习的正态分布2 定义生成网络(Generator)3 定义对抗网络(Adversarial)4 定义数据输入方式5 学习率6 生成器(Generator)7 鉴别器(Discriminator)8 搭建网络9 训练10 结果11 代码地址定义我们要学习的正态分布

我们定一个正态分布,它的均值和标准差如下:

123data_mean = 3.0data_stddev = 0.4Series_Length = 30

定义生成网络(Generator)

我们的生成网络接收一些随机输入,按照上面的定义生成正态分布。你可以在代码里改变这些变量的值来看它们对最终结果的影响:

123g_input_size = 20    g_hidden_size = 150  g_output_size = Series_Length

定义对抗网络(Adversarial)

我们的对抗网络输出如下:

True(1.0) 如果输入的数据符合定义的正态分布False(0.0) 如果输入的数据不符合定义的正态分布

123d_input_size = Series_Lengthd_hidden_size = 75   d_output_size = 1

定义数据输入方式

1234d_minibatch_size = 15 g_minibatch_size = 10num_epochs = 5000print_interval = 1000

学习率

下面的学习率你也可以试着变一下做做实验,如果太小会影响收敛。

12d_learning_rate = 3e-3g_learning_rate = 8e-3

下面的两个函数一个可以得到真正的分布,一个可以得到噪声。真正的分布用来训练 Discriminator,噪声用来作为 Generator 的输入。

123456789def get_real_sampler(mu, sigma):    dist = Normal( mu, sigma )    return lambda m, n: dist.sample( (m, n) ).requires_grad_() def get_noise_sampler():    return lambda m, n: torch.rand(m, n).requires_grad_()  # Uniform-dist data into generator, _NOT_ Gaussian actual_data = get_real_sampler( data_mean, data_stddev )noise_data  = get_noise_sampler()

生成器(Generator)

生成器用来输出符合我们想要的正态分布的均值。很简单的一个 4 层网络。

1234567891011class Generator(nn.Module):    def __init__(self, input_size, hidden_size, output_size):        super(Generator, self).__init__()        self.map1 = nn.Linear(input_size, hidden_size)        self.map2 = nn.Linear(hidden_size, hidden_size)        self.map3 = nn.Linear(hidden_size, output_size)        self.xfer = torch.nn.SELU()    def forward(self, x):        x = self.xfer( self.map1(x) )        x = self.xfer( self.map2(x) )        return self.xfer( self.map3( x ) )

鉴别器(Discriminator)

非常简单的 Linear 模型,返回 True 或者 False。

123456789101112class Discriminator(nn.Module):    def __init__(self, input_size, hidden_size, output_size):        super(Discriminator, self).__init__()        self.map1 = nn.Linear(input_size, hidden_size)        self.map2 = nn.Linear(hidden_size, hidden_size)        self.map3 = nn.Linear(hidden_size, output_size)        self.elu = torch.nn.ELU()     def forward(self, x):        x = self.elu(self.map1(x))        x = self.elu(self.map2(x))        return torch.sigmoid( self.map3(x) )

搭建网络

我们使用 BCE 损失函数,SGD 优化函数。

123456G = Generator(input_size=g_input_size, hidden_size=g_hidden_size, output_size=g_output_size)D = Discriminator(input_size=d_input_size, hidden_size=d_hidden_size, output_size=d_output_size) criterion = nn.BCELoss()d_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate ) g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate )

训练

1234567891011121314151617181920212223242526272829303132333435363738def train_D_on_actual() :    real_data = actual_data( d_minibatch_size, d_input_size )    real_decision = D( real_data )    real_error = criterion( real_decision, torch.ones( d_minibatch_size, 1 ))  # ones = true    real_error.backward() def train_D_on_generated() :    noise = noise_data( d_minibatch_size, g_input_size )    fake_data = G( noise )     fake_decision = D( fake_data )    fake_error = criterion( fake_decision, torch.zeros( d_minibatch_size, 1 ))  # zeros = fake    fake_error.backward() def train_G():    noise = noise_data( g_minibatch_size, g_input_size )    fake_data = G( noise )    fake_decision = D( fake_data )    error = criterion( fake_decision, torch.ones( g_minibatch_size, 1 ) )     error.backward()    return error.item(), fake_data losses = []for epoch in range(num_epochs):    D.zero_grad()        train_D_on_actual()        train_D_on_generated()    d_optimizer.step()        G.zero_grad()    loss,generated = train_G()    g_optimizer.step()        losses.append( loss )    if( epoch % print_interval) == (print_interval-1) :        print( "Epoch %6d. Loss %5.3f" % ( epoch 1, loss ) )        print( "Training complete" )

结果

训练完成后我们展示一些结果:

1234567891011import matplotlib.pyplot as pltdef draw( data ) :        plt.figure()    d = data.tolist() if isinstance(data, torch.Tensor ) else data    plt.plot( d )     plt.show() d = torch.empty( generated.size(0), 53 ) for i in range( 0, d.size(0) ) :    d[i] = torch.histc( generated[i], min=0, max=5, bins=53 )draw( d.t() )

代码地址

代码见 https://github.com/rcorbish/pytorch-notebooks。

本站微信群、QQ群(三群号 726282629):



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3