从零实现一个小小小的扩散模型

您所在的位置:网站首页 cs16aug模型 从零实现一个小小小的扩散模型

从零实现一个小小小的扩散模型

2023-04-10 03:27| 来源: 网络整理| 查看: 265

安装

调包之前确认你已经安装了相应的库,需要pytorch、matplotlib。

然后再安装diffusers

pip install -q diffusers 复制代码 数据 import torch import torchvision from torch import nn from torch.nn import functional as F from torch.utils.data import DataLoader from diffusers import DDPMScheduler, UNet2DModel from matplotlib import pyplot as plt device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f'Using device: {device}') 复制代码

如果你用GPU的话现在这里输出应该是:

Using device: cuda

这里就用最简单mnist数据集,当然如果你想换别的数据集自行更换就OK。

pytorch传统艺能 用DataLoader加载数据

dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor()) train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True) # pytorch传统艺能 用DataLoader加载数据 复制代码

我们使用DataLoader()读取数据后,用next(iter(data_iter))来返回批量数据,而不能使用 next(data_iter),原理就在这儿。 使用迭代器来返回批量数据,可在大量数据情况下,实现小批量循环迭代式的读取,避免了内存不足的问题。

x, y = next(iter(train_dataloader)) print('Input shape:', x.shape) print('Labels:', y) plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys'); 复制代码

可以看到输出

Input shape: torch.Size([8, 1, 28, 28])

Labels: tensor([1, 9, 7, 3, 5, 2, 1, 4])

image.png

每张图像都是一个28x28像素的灰度手写数字,像素值范围从0到1。我们上边设置一个batch_size大小为8,所以这里一组图出来8个。

注意:因为这batch取的是随机的,所以你每次运行显示的八张图是不一样的,和我结果不一样没关系。

加噪过程

假设你还没有阅读过任何扩散模型论文,现在告诉你,扩散模型的前向过程需要给图片加噪声。现在给你提供一个简单的加噪方式:

noise = torch.rand_like(x)

noisy_x = (1-amount)*x + amount*noise

如果amount = 0,则返回输入图像不做任何更改。如果amount = 1,则返回纯噪声。

通过这种方式混合输入和噪声,我们可以保持输出在相同的范围内(0到1)。并且这样比较容易实现。写代码时候要注意Tensor的形状,以免pytorch的广播机制被破坏。

def corrupt(x, amount): # 根据输入的amount 对 图像加噪 noise = torch.rand_like(x) amount = amount.view(-1, 1, 1, 1) # 使用.view方法修改形状 return x*(1-amount) + noise*amount 复制代码

让我们看看这个加噪代码的效果如何:

# 显示一下输入图像 fig, axs = plt.subplots(2, 1, figsize=(12, 5)) axs[0].set_title('Input data') axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap='Greys') # 为图像添加噪声 amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption noised_x = corrupt(x, amount) # 画出添加噪声之后的图像 axs[1].set_title('Corrupted data (-- amount increases -->)') axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap='Greys'); 复制代码

torch.linspace(0, 1, x.shape[0]) 就是生成0-1之间的数,生成数量为一个batch,也就是生成一个等差数列作为噪声添加给一个batch的图像。

image.png

模型

我们希望我们的模型可以接受一个28*28像素大小的带噪声的输入,并输出一个相同形状的去噪结果。在常规的扩散模型中使用的是U-net。不了解的可以看这个文章:浅谈语义分割网络U-Net。该模型最初是为医学图像分割任务而发明的,一个U-Net由一个“压缩路径”和一个“扩展路径”组成,数据通过“压缩路径”被压缩,然后通过“扩展路径”恢复到原始尺寸(类似于自动编码器),但还包括skip connection,允许数据在不同层次上传递信息和梯度。

一些U-Net在每个阶段都有复杂的组成模块,但我们今天只是简单实现一个扩散模型,所以也不搞什么复杂的U-net结构 了,我们在这里构建一个最最简单的U-net示例,模型可以接收一个单通道图像,并通过压缩路径上的三个卷积层(图表和代码中的down_layers)和扩展路径上的三个卷积层,在下行和上行层之间使用跳过连接。模型中使用最大池化进行下采样,使用nn.Upsample进行上采样。下图是大致的架构,显示每个层输出的通道数:

image.png

代码里用的激活函数是torch.nn.SiLU():

silu⁡(x)=x∗sigmoid⁡(x)=x11+e−x\operatorname{silu}(x)=x * \operatorname{sigmoid}(x)=x \frac{1}{1+e^{-x}}silu(x)=x∗sigmoid(x)=x1+e−x1​

在压缩路径中,数据通过三个卷积层(存储在self.down_layers中)和最大池化层进行下采样。在每层之后都应用激活函数(存储在self.act中)。对于前两个压缩层,它们的输出还被存储在h列表中,以便在扩展路径中使用它们进行skip connection。在扩展路径中,数据通过三个卷积层(存储在self.up_layers中)进行上采样,并执行skip connection。在每层之后,也要应用激活函数。注意,在第一个上行层之前没有跳跃连接。

class BasicUNet(nn.Module): # 简易版的U-net def __init__(self, in_channels=1, out_channels=1): super().__init__() self.down_layers = torch.nn.ModuleList([ nn.Conv2d(in_channels, 32, kernel_size=5, padding=2), nn.Conv2d(32, 64, kernel_size=5, padding=2), nn.Conv2d(64, 64, kernel_size=5, padding=2), ]) self.up_layers = torch.nn.ModuleList([ nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.Conv2d(64, 32, kernel_size=5, padding=2), nn.Conv2d(32, out_channels, kernel_size=5, padding=2), ]) # 激活函数 self.act = nn.SiLU() self.downscale = nn.MaxPool2d(2) self.upscale = nn.Upsample(scale_factor=2) def forward(self, x): h = [] for i, l in enumerate(self.down_layers): x = self.act(l(x)) if i < 2: h.append(x) x = self.downscale(x) for i, l in enumerate(self.up_layers): if i > 0: x = self.upscale(x) x += h.pop() x = self.act(l(x)) return x 复制代码

验证一下输入输出是否保持同维度。

net = BasicUNet() x = torch.rand(8, 1, 28, 28) net(x).shape 复制代码

这里我们可以看到输出是:

torch.Size([8, 1, 28, 28])

训练

说一下我们训练阶段的设定。先想一下我们这个模型要做什么:

丢给它带噪声的图片,模型应该生成其去噪结果

所以给定带噪声的noisy_x,模型要努力去恢复x。

这里我们使用均方误差计算模型输出和原始图像的差异。

接下来就是训练模型部分的代码:

拿到一个batch的数据

破坏数据模拟前向前向加噪

将破坏后的数据放入模型中

将模型输出结果和原始图像进行比较并计算loss

更新模型参数

# Dataloader 加载数据 batch_size = 128 train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) # 我们要训练多少轮 n_epochs = 10 # 创建网络,将模型丢到GPU上(如果有GPU的话) net = BasicUNet() net.to(device) # 损失函数是用的MSE loss loss_fn = nn.MSELoss() # 优化器使用的Adam opt = torch.optim.Adam(net.parameters(), lr=1e-3) # 记录损失 losses = [] # 训练循环 for epoch in range(n_epochs): for x, y in train_dataloader: # 准备好输入数据和加噪数据 # 把数据放到GPU上(如果你有的话) x = x.to(device) #设定随机噪声 noise_amount = torch.rand(x.shape[0]).to(device)   # 处理x,获得加噪之后的样本noisy_x noisy_x = corrupt(x, noise_amount) # 获取模型输出结果 pred = net(noisy_x) # 计算loss loss = loss_fn(pred, x) # How close is the output to the true 'clean' x? # 反向传播更新模型参数 opt.zero_grad() loss.backward() opt.step() # 存储loss记录 losses.append(loss.item()) # 输出每轮训练的loss的平均值 avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader) print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}') # 画一下loss plt.plot(losses) plt.ylim(0, 0.1); 复制代码

可以看到输出结果如下:

Finished epoch 0. Average loss for this epoch: 0.025983

Finished epoch 1. Average loss for this epoch: 0.020247

Finished epoch 2. Average loss for this epoch: 0.018660

Finished epoch 3. Average loss for this epoch: 0.017662

Finished epoch 4. Average loss for this epoch: 0.016999

Finished epoch 5. Average loss for this epoch: 0.016730

Finished epoch 6. Average loss for this epoch: 0.016610

Finished epoch 7. Average loss for this epoch: 0.016287

Finished epoch 8. Average loss for this epoch: 0.016084

Finished epoch 9. Average loss for this epoch: 0.015731

image.png

模型训练咋样了?

带兄弟们看一下效果嗷:

# 取一组图像 x, y = next(iter(train_dataloader)) x = x[:8] # Only using the first 8 for easy plotting # 用我们前边给八张图加噪的那个方法,看看模型对不同程度的噪声的回复情况 amount = torch.linspace(0, 1, x.shape[0]) noised_x = corrupt(x, amount) # 获取模型结构 with torch.no_grad(): preds = net(noised_x.to(device)).detach().cpu() # 画出结果来 fig, axs = plt.subplots(3, 1, figsize=(12, 7)) axs[0].set_title('Input data') axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap='Greys') axs[1].set_title('Corrupted data') axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap='Greys') axs[2].set_title('Network Predictions') axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap='Greys'); 复制代码

下边我放了是三个运行结果,可以看到倒数第二项开始,恢复结果就不咋地了。但是作为一个简单模型,能获得这样的效果已经狠不戳了,随着我们逐渐优化模型,会获得更好的效果的!

实例1 image.png

实例2 image.png

实例3 image.png

本文正在参加「金石计划」



【本文地址】


今日新闻


推荐新闻


    CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3