用 PyTorch 实现一个基本 GAN 网络学习正态分布 |
您所在的位置:网站首页 › pytorch搭建GAN › 用 PyTorch 实现一个基本 GAN 网络学习正态分布 |
这篇文章将用 PyTorch 实现一个基本的生成对抗网络(Generative Adversarial Network, GAN),来学习一个正态分布。代码提供 Jupiter Notebook,地址见文末。 首先我们导入一些必备的库: 1234import torchimport torch.nn as nnimport torch.optim as optimfrom torch.distributions.normal import Normal文章目录 1 定义我们要学习的正态分布2 定义生成网络(Generator)3 定义对抗网络(Adversarial)4 定义数据输入方式5 学习率6 生成器(Generator)7 鉴别器(Discriminator)8 搭建网络9 训练10 结果11 代码地址定义我们要学习的正态分布我们定一个正态分布,它的均值和标准差如下: 123data_mean = 3.0data_stddev = 0.4Series_Length = 30定义生成网络(Generator)我们的生成网络接收一些随机输入,按照上面的定义生成正态分布。你可以在代码里改变这些变量的值来看它们对最终结果的影响: 123g_input_size = 20 g_hidden_size = 150 g_output_size = Series_Length定义对抗网络(Adversarial)我们的对抗网络输出如下: True(1.0) 如果输入的数据符合定义的正态分布False(0.0) 如果输入的数据不符合定义的正态分布123d_input_size = Series_Lengthd_hidden_size = 75 d_output_size = 1定义数据输入方式1234d_minibatch_size = 15 g_minibatch_size = 10num_epochs = 5000print_interval = 1000学习率下面的学习率你也可以试着变一下做做实验,如果太小会影响收敛。 12d_learning_rate = 3e-3g_learning_rate = 8e-3下面的两个函数一个可以得到真正的分布,一个可以得到噪声。真正的分布用来训练 Discriminator,噪声用来作为 Generator 的输入。 123456789def get_real_sampler(mu, sigma): dist = Normal( mu, sigma ) return lambda m, n: dist.sample( (m, n) ).requires_grad_() def get_noise_sampler(): return lambda m, n: torch.rand(m, n).requires_grad_() # Uniform-dist data into generator, _NOT_ Gaussian actual_data = get_real_sampler( data_mean, data_stddev )noise_data = get_noise_sampler() 生成器(Generator)生成器用来输出符合我们想要的正态分布的均值。很简单的一个 4 层网络。 1234567891011class Generator(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(Generator, self).__init__() self.map1 = nn.Linear(input_size, hidden_size) self.map2 = nn.Linear(hidden_size, hidden_size) self.map3 = nn.Linear(hidden_size, output_size) self.xfer = torch.nn.SELU() def forward(self, x): x = self.xfer( self.map1(x) ) x = self.xfer( self.map2(x) ) return self.xfer( self.map3( x ) )鉴别器(Discriminator)非常简单的 Linear 模型,返回 True 或者 False。 123456789101112class Discriminator(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(Discriminator, self).__init__() self.map1 = nn.Linear(input_size, hidden_size) self.map2 = nn.Linear(hidden_size, hidden_size) self.map3 = nn.Linear(hidden_size, output_size) self.elu = torch.nn.ELU() def forward(self, x): x = self.elu(self.map1(x)) x = self.elu(self.map2(x)) return torch.sigmoid( self.map3(x) )搭建网络我们使用 BCE 损失函数,SGD 优化函数。 123456G = Generator(input_size=g_input_size, hidden_size=g_hidden_size, output_size=g_output_size)D = Discriminator(input_size=d_input_size, hidden_size=d_hidden_size, output_size=d_output_size) criterion = nn.BCELoss()d_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate ) g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate )训练1234567891011121314151617181920212223242526272829303132333435363738def train_D_on_actual() : real_data = actual_data( d_minibatch_size, d_input_size ) real_decision = D( real_data ) real_error = criterion( real_decision, torch.ones( d_minibatch_size, 1 )) # ones = true real_error.backward() def train_D_on_generated() : noise = noise_data( d_minibatch_size, g_input_size ) fake_data = G( noise ) fake_decision = D( fake_data ) fake_error = criterion( fake_decision, torch.zeros( d_minibatch_size, 1 )) # zeros = fake fake_error.backward() def train_G(): noise = noise_data( g_minibatch_size, g_input_size ) fake_data = G( noise ) fake_decision = D( fake_data ) error = criterion( fake_decision, torch.ones( g_minibatch_size, 1 ) ) error.backward() return error.item(), fake_data losses = []for epoch in range(num_epochs): D.zero_grad() train_D_on_actual() train_D_on_generated() d_optimizer.step() G.zero_grad() loss,generated = train_G() g_optimizer.step() losses.append( loss ) if( epoch % print_interval) == (print_interval-1) : print( "Epoch %6d. Loss %5.3f" % ( epoch 1, loss ) ) print( "Training complete" )结果训练完成后我们展示一些结果: 1234567891011import matplotlib.pyplot as pltdef draw( data ) : plt.figure() d = data.tolist() if isinstance(data, torch.Tensor ) else data plt.plot( d ) plt.show() d = torch.empty( generated.size(0), 53 ) for i in range( 0, d.size(0) ) : d[i] = torch.histc( generated[i], min=0, max=5, bins=53 )draw( d.t() )代码见 https://github.com/rcorbish/pytorch-notebooks。 本站微信群、QQ群(三群号 726282629): ![]() |
今日新闻 |
推荐新闻 |
CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3 |